Revert "Batch Norm Consolidation (#116092)"

This reverts commit 7b4f70eda519ccd7f28de17689edd43c52743bc9.

Reverted https://github.com/pytorch/pytorch/pull/116092 on behalf of https://github.com/osalpekar because it causes a build failure in the //caffe2:aten-hip (AMD build) target. See [D54707318](https://www.internalfb.com/diff/D54707318) for more details; resolving it may require internal build system changes. ([comment](https://github.com/pytorch/pytorch/pull/116092#issuecomment-1989542965))
PyTorch MergeBot committed 2024-03-11 22:22:39 +00:00
parent 498a94a7f5
commit fd0dbcd891
36 changed files with 72 additions and 772 deletions


@@ -708,11 +708,8 @@ meta_function_device_expected_failures_only_outplace = defaultdict(dict)
 meta_function_device_skips = defaultdict(dict)
 
 meta_function_device_expected_failures['cpu'] = {
-    # TODO: The decomps for these batch norm ops return different dtypes depending
-    # on the device. We should make this work better with meta tensors.
     torch.native_batch_norm: {bf16, f16},
     torch._native_batch_norm_legit: {bf16, f16},
-    torch.ops.aten._batch_norm_with_update: {bf16, f16},
     torch.native_layer_norm: {bf16, f16},
 }
 
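For context (not part of the patch itself): the TODO comments removed above refer to batch norm decompositions whose output dtypes differ by device, which the meta kernels cannot currently mirror. Below is a minimal sketch of how that divergence surfaces when the same op runs on a real CPU tensor versus a meta tensor; the helper name, shapes, and dtype are illustrative assumptions, not code from this PR.

import torch

def batch_norm_output_dtypes(device):
    # Hypothetical helper: run batch norm in training mode with no running
    # stats and report the dtypes of all three outputs.
    x = torch.randn(2, 3, 4, device=device, dtype=torch.bfloat16)
    weight = torch.randn(3, device=device, dtype=torch.bfloat16)
    bias = torch.randn(3, device=device, dtype=torch.bfloat16)
    # native_batch_norm returns (output, save_mean, save_invstd)
    out, save_mean, save_invstd = torch.native_batch_norm(
        x, weight, bias, None, None, True, 0.1, 1e-5
    )
    return out.dtype, save_mean.dtype, save_invstd.dtype

# If the CPU kernel and the meta kernel disagree on the dtype of the saved
# statistics for bf16/f16 inputs, that op/dtype pair has to be listed in
# expected-failure tables like the ones in this diff.
print(batch_norm_output_dtypes("cpu"))
print(batch_norm_output_dtypes("meta"))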
@@ -727,11 +724,8 @@ meta_function_device_expected_failures['cuda'] = {
 }
 
 meta_function_device_skips['cpu'] = {
-    # TODO: The decomps for these batch norm ops return different dtypes depending
-    # on the device. We should make this work better with meta tensors.
     torch.native_batch_norm: {f32, f64},
     torch._native_batch_norm_legit: {f32, f64},
-    torch.ops.aten._batch_norm_with_update: {f32, f64},
 }
 
 meta_function_device_skips['cuda'] = {
@@ -856,13 +850,9 @@ meta_dispatch_device_expected_failures = defaultdict(dict)
 meta_dispatch_device_skips = defaultdict(dict)
 
 meta_dispatch_device_expected_failures['cpu'] = {
-    # TODO: The decomps for these batch norm ops return different dtypes depending
-    # on the device. We should make this work better with meta tensors.
     aten.native_batch_norm.default: {bf16, f16},
     aten._native_batch_norm_legit.default: {bf16, f16},
-    aten._native_batch_norm_legit.no_stats: {bf16, f16},
-    aten._batch_norm_with_update.default: {bf16, f16},
     aten.native_layer_norm.default: {bf16, f16},
     aten.histc.default: {f16},
     aten.histc.out: {f16},
@@ -887,13 +877,9 @@ meta_dispatch_device_expected_failures['cuda'] = {
 }
 
 meta_dispatch_device_skips['cpu'] = {
     aten._embedding_bag_forward_only.default: {bf16, f16, f32, f64},
-    # TODO: The decomps for these batch norm ops return different dtypes depending
-    # on the device. We should make this work better with meta tensors.
     aten.native_batch_norm.default: {f32, f64},
     aten._native_batch_norm_legit.default: {f32, f64},
-    aten._native_batch_norm_legit.no_stats: {f32, f64},
-    aten._batch_norm_with_update.default: {f32, f64},
     # If the computation dtype is different from the input
     # dtype this will fail. CPU execution may also have a