mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Batch Norm Consolidation (#116092)
**Summary:** This commit simplifies the existing decomposition hierarchy of batch norm ops by adding a single, backend agnostic op: `batch_norm_with_update`. The existing hierarchy looks like: ``` aten.batch_norm -> aten._batch_norm_impl_index -> [ aten.native_batch_norm -> aten._native_batch_norm_legit (export only) -> _batch_norm_legit_cpu/cuda (kernels, export only) -> _batch_norm_cpu/cuda (kernels) ] OR [ aten.cudnn_batch_norm ] OR [ aten.miopen_batch_norm ] ``` Aside from complexity, an important problem with the above decomposition hierarchy is cuda numerics in export flows. We observed significantly worse convergence when training a mobilenetv2-like model when using the `_batch_norm_cuda` kernel instead of the `cudnn_batch_norm` kernel. This means users who export their models on CPU first then move the models to cuda later may silently see worse accuracies even when cudnn is installed, because they are using the worse kernel. This issue is summarized in https://github.com/pytorch/pytorch/issues/111384. Instead, the new hierarchy proposed by consolidating existing batch norm ops will look like: ``` aten.batch_norm -> aten.batch_norm_with_update -> [ _batch_norm_cpu (kernel) ] OR [ _batch_norm_cuda (kernel) ] OR [ cudnn_batch_norm (kernel) ] OR [ miopen_batch_norm (kernel) ] ``` The new op `batch_norm_with_update` hides backend implementation details and automatically picks the right kernel based on what is installed. This commit also adds the following variants to this op: ``` batch_norm_with_update_functional batch_norm_with_update.out batch_norm_no_update batch_norm_no_update.out batch_norm_backward ``` Note that this commit only adds this op and its variants, but does not actually change the decomps to produce these ops in the graph. This will be done after the 2 week FC window, and the ops used in the old stack is planned to be removed after the 6 month BC window. Test Plan: `OpInfo` tests for `batch_norm_with_update`. Reviewers: albanD, bdhirsh Subscribers: albanD, bdhirsh, supriyar Tasks: https://github.com/pytorch/pytorch/issues/111384 Co-authored-by: Tugsbayasgalan Manlaibaatar <tmanlaibaatar@fb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/116092 Approved by: https://github.com/bdhirsh, https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
c253d1c1db
commit
7b4f70eda5
Reference in New Issue
Block a user