mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add RMSNorm module (#121364)
Similar todbeed9724b/torchmultimodal/modules/layers/normalizations.py (L51)
**The implementation here is not optimized and we welcome pull requests to improve this** - Use `normalized_shape` instead of singular integer `dim` to be aligned with the `nn.LayerNorm` implementation - Remove the [upcast to float and downcast ](dbeed9724b/torchmultimodal/modules/layers/normalizations.py (L73)
) Differential Revision: [](https://our.internmc.facebook.com/intern/diff/) Differential Revision: [D55485840](https://our.internmc.facebook.com/intern/diff/D55485840) Pull Request resolved: https://github.com/pytorch/pytorch/pull/121364 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
3243be7c3a
commit
487b6d40ec
@ -4287,6 +4287,7 @@ class TestFunctionalTracing(JitTestCase):
|
||||
"fractional_max_pool2d_with_indices": ARG_TYPE_MISMATCH,
|
||||
"fractional_max_pool3d_with_indices": ARG_TYPE_MISMATCH,
|
||||
"layer_norm": ARG_TYPE_MISMATCH,
|
||||
"rms_norm": ARG_TYPE_MISMATCH,
|
||||
"lp_pool1d": ARG_TYPE_MISMATCH,
|
||||
|
||||
"affine_grid": CONTROL_FLOW,
|
||||
|
Reference in New Issue
Block a user