Fused RMSNorm Housekeeping (#159317)
Small PR to address comments from the original fused RMSNorm PR that were not landed.

Changes:
- Warning message when input.dtype doesn't match weight.dtype
- Ensure the default epsilon value is correct

Comments:
https://github.com/pytorch/pytorch/pull/153666#discussion_r2114735005
https://github.com/pytorch/pytorch/pull/153666#discussion_r2223518064

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159317
Approved by: https://github.com/ngimel, https://github.com/Skylion007, https://github.com/eqy
Committed by: PyTorch MergeBot
Parent: b4619f0272
Commit: dc286aef61
@@ -342,8 +342,8 @@ Tensor rms_norm_symint(
   if (weight_opt.has_value() && weight_opt.value().defined() && weight_opt.value().dtype() != input.dtype()) {
     TORCH_WARN_ONCE(
-      "Mismatch dtype between input and module: input dtype = ", input.dtype(),
-      ", module dtype = ", weight_opt.value().dtype(), ", Can not dispatch to fused implementation"
+      "Mismatch dtype between input and weight: input dtype = ", input.dtype(),
+      ", weight dtype = ", weight_opt.value().dtype(), ", Cannot dispatch to fused implementation."
     );
     return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast<const int64_t*>(normalized_shape.data()), normalized_shape.size()), weight_opt, eps));
   }
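With this change, a dtype mismatch no longer mentions a "module"; the warning names the weight, and the call falls back to rms_norm_composite. A minimal sketch of triggering it from Python (assuming a build that includes this patch; the shapes and dtypes below are arbitrary):

    import torch
    import torch.nn.functional as F

    # float32 input against a float16 weight: the dtype check above fires,
    # TORCH_WARN_ONCE emits the new message, and the composite
    # implementation is used instead of the fused kernel.
    x = torch.randn(4, 8, dtype=torch.float32)
    weight = torch.ones(8, dtype=torch.float16)

    y = F.rms_norm(x, normalized_shape=(8,), weight=weight)

Because the warning goes through TORCH_WARN_ONCE, repeating the call in the same process prints it only once.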
@@ -8748,6 +8748,30 @@ class TestNNDeviceType(NNTestCase):
 
         self.assertEqual(Y_ref, Y)
 
+    @onlyNativeDeviceTypes
+    @dtypes(torch.float16, torch.bfloat16, torch.float32, torch.float64)
+    def test_rmsnorm_epsilon(self, device, dtype):
+        def rms_norm_reference_fn(i, normalized_shape):
+            eps = torch.finfo(i.dtype).eps
+            ndim = i.ndim
+            dims = [ndim - i - 1 for i in range(len(normalized_shape))]
+            if i.dtype is not torch.float64:
+                upcasted_i = i.float()
+            else:
+                upcasted_i = i
+            result = upcasted_i * torch.rsqrt(
+                upcasted_i.pow(2).mean(dim=dims, keepdim=True) + eps
+            )
+            return result.type_as(i)
+
+        shape = (2, 2)
+        X = torch.tensor([[1e-12, -1e-12], [1e-12, -1e-12]], dtype=dtype, device=device)
+
+        Y = torch.nn.functional.rms_norm(X, shape)
+        Y_ref = rms_norm_reference_fn(X, shape)
+
+        self.assertEqual(Y_ref, Y)
+
     @onlyCPU
     def test_glu_bfloat16(self, device):
         def test_dtype(fn, input, dtype):
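The test pins down the default epsilon: when eps is omitted, rms_norm uses torch.finfo(input.dtype).eps, which keeps torch.rsqrt finite even when mean(x**2) underflows toward zero, as with the 1e-12 inputs above. A standalone sketch of the same check, restricted to float32 so no upcast is needed:

    import torch
    import torch.nn.functional as F

    x = torch.tensor([[1e-12, -1e-12], [1e-12, -1e-12]], dtype=torch.float32)

    # normalized_shape equals the full shape, so the mean runs over all dims;
    # the default eps (torch.finfo(x.dtype).eps) dominates the tiny mean square.
    y = F.rms_norm(x, (2, 2))

    eps = torch.finfo(x.dtype).eps
    y_ref = x * torch.rsqrt(x.pow(2).mean(dim=(0, 1), keepdim=True) + eps)

    torch.testing.assert_close(y, y_ref)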