Add RMSNorm module (#121364)

Similar to dbeed9724b/torchmultimodal/modules/layers/normalizations.py (L51)

**The implementation here is not optimized and we welcome pull requests to improve this**

- Use `normalized_shape` instead of singular integer `dim` to be aligned with the `nn.LayerNorm` implementation
- Remove the [upcast to float and downcast
](dbeed9724b/torchmultimodal/modules/layers/normalizations.py (L73))

Differential Revision: [](https://our.internmc.facebook.com/intern/diff/)

Differential Revision: [D55485840](https://our.internmc.facebook.com/intern/diff/D55485840)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121364
Approved by: https://github.com/albanD
This commit is contained in:
Mikayla Gawarecki
2024-03-25 12:58:08 -07:00
committed by PyTorch MergeBot
parent 3243be7c3a
commit 487b6d40ec
16 changed files with 307 additions and 5 deletions

View File

@ -4287,6 +4287,7 @@ class TestFunctionalTracing(JitTestCase):
"fractional_max_pool2d_with_indices": ARG_TYPE_MISMATCH,
"fractional_max_pool3d_with_indices": ARG_TYPE_MISMATCH,
"layer_norm": ARG_TYPE_MISMATCH,
"rms_norm": ARG_TYPE_MISMATCH,
"lp_pool1d": ARG_TYPE_MISMATCH,
"affine_grid": CONTROL_FLOW,