From fd0dbcd891878a743872747e37cddb9f10fe9df7 Mon Sep 17 00:00:00 2001
From: PyTorch MergeBot
Date: Mon, 11 Mar 2024 22:22:39 +0000
Subject: [PATCH] Revert "Batch Norm Consolidation (#116092)"

This reverts commit 7b4f70eda519ccd7f28de17689edd43c52743bc9.

Reverted https://github.com/pytorch/pytorch/pull/116092 on behalf of https://github.com/osalpekar due to Causes build failure in //caffe2:aten-hip (AMD build) target. See [D54707318](https://www.internalfb.com/diff/D54707318) for more details, may require internal build system changes to resolve. ([comment](https://github.com/pytorch/pytorch/pull/116092#issuecomment-1989542965))
---
 aten/src/ATen/native/Normalization.cpp         | 150 ++++--------
 aten/src/ATen/native/Normalization.h           |   8 -
 aten/src/ATen/native/cuda/Normalization.cu     |  78 ---------
 aten/src/ATen/native/cudnn/BatchNorm.cpp       |  31 +---
 aten/src/ATen/native/cudnn/BatchNorm.h         |   6 -
 aten/src/ATen/native/mkldnn/Normalization.cpp  |  36 -----
 .../native/mps/operations/Normalization.mm     |  55 -------
 aten/src/ATen/native/native_functions.yaml     |  26 ---
 test/distributed/_tensor/test_dtensor_ops.py   |   1 -
 ...austive_batch_norm_with_update_cpu_float32  |   0
 ...austive_batch_norm_with_update_cpu_float32  |   0
 ...austive_batch_norm_with_update_cpu_float32  |   0
 ...inplace_batch_norm_with_update_cpu_float32  |   0
 ...ive_out_batch_norm_with_update_cpu_float32  |   0
 ...DecompTest.test_aten_core_operators.expect  |   4 -
 ...asDecompTest.test_has_decomposition.expect  |   2 -
 test/functorch/test_ops.py                     | 105 ++++--------
 test/functorch/test_vmap.py                    |   4 -
 test/inductor/test_torchinductor_opinfo.py     |   1 -
 test/onnx/test_fx_op_consistency.py            |  19 ---
 test/test_jit_fuser_te.py                      |   3 +-
 test/test_meta.py                              |  14 --
 test/test_mps.py                               |   1 -
 test/test_proxy_tensor.py                      |   9 --
 tools/autograd/derivatives.yaml                |  14 --
 tools/autograd/gen_python_functions.py         |   5 +-
 torch/_C/__init__.pyi.in                       |   3 -
 torch/_decomp/decompositions.py                | 136 ----------------
 torch/_decomp/decompositions_for_jvp.py        |  29 ----
 torch/_dynamo/trace_rules.py                   |   1 -
 torch/_functorch/partitioners.py               |   2 +-
 torch/_inductor/decomposition.py               |   4 -
 torch/csrc/Module.cpp                          |  42 -----
 .../serialized_shape_function_registry.cpp     |   1 -
 torch/jit/_shape_functions.py                  |   5 -
 .../_internal/common_methods_invocations.py    |  49 ------
 36 files changed, 72 insertions(+), 772 deletions(-)
 delete mode 100644 aten/src/ATen/native/cudnn/BatchNorm.h
 delete mode 100644 test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_exhaustive_batch_norm_with_update_cpu_float32
 delete mode 100644 test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_fake_exhaustive_batch_norm_with_update_cpu_float32
 delete mode 100644 test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_batch_norm_with_update_cpu_float32
 delete mode 100644 test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_inplace_batch_norm_with_update_cpu_float32
 delete mode 100644 test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_out_batch_norm_with_update_cpu_float32

diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp
index 64de95827cf2..2b00567a4e84 100644
--- a/aten/src/ATen/native/Normalization.cpp
+++ b/aten/src/ATen/native/Normalization.cpp
@@ -29,11 +29,6 @@
 #include
 #include
 #include
-#include
-#include
-#include
-#include
-#include
 #include
 #include
 #include
@@ -484,58 +479,10 @@ std::tuple batch_norm_backward_cpu_template(
   return std::make_tuple(grad_input, grad_weight, grad_bias);
 }
 
-BatchNormBackend _select_batch_norm_backend(
-    const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean,
-    const Tensor& running_var, bool training, double eps) {
-
-  auto& ctx = at::globalContext();
-  bool cudnn_enabled = ctx.userEnabledCuDNN();
-
-  if (
-      input.is_cuda()
-      && input.scalar_type() != at::kBFloat16 && weight.scalar_type() != at::kBFloat16
-      && (input.scalar_type() != at::kHalf
-        || weight.scalar_type() == at::kFloat)
-      && weight.defined() && bias.defined()
-      && ((running_mean.defined() && running_var.defined())
-        || (!running_mean.defined() && !running_var.defined() && training))
-      && (input.dim() >= 3)
-      && ((input.sym_size(0) <= 880801 && training) // spatial, training
-        ||(input.sym_size(0) <= 65535 && !training)) //spatial, eval
-      && detail::getCUDAHooks().compiledWithCuDNN()
-      && eps >= detail::getCUDAHooks().batchnormMinEpsilonCuDNN()
-      && cudnn_enabled && detail::getCUDAHooks().versionCuDNN() >= 5110L
-      && input.sym_numel() < std::numeric_limits::max() // some cuDNN kernels have 32-bit indexing limitations
-      ) {
-    return BatchNormBackend::Cudnn;
-  }
-
-  if (
-      input.is_cuda()
-      && input.dim() <= MIOPEN_DIM_MAX
-      && input.scalar_type() != at::kDouble
-      && input.scalar_type() != at::kBFloat16
-      && (weight.scalar_type() != at::kHalf)
-      && weight.defined() && bias.defined()
-      && ((running_mean.defined() && running_var.defined())
-        || (!running_mean.defined() && !running_var.defined() && training))
-      && detail::getCUDAHooks().compiledWithMIOpen()
-      && cudnn_enabled
-      && input.suggest_memory_format() != MemoryFormat::ChannelsLast
-      && input.suggest_memory_format() != MemoryFormat::ChannelsLast3d
-      ) {
-    return BatchNormBackend::Miopen;
-  }
-
-  return BatchNormBackend::Native;
-}
-
-
 // _batch_norm_impl_index(_backward) are used in the JIT be able to keep the run-time selection
 // of backends, while enabling it to keep the information about the used backend, so that it can
 // use its corresponding backward implementation.
 // XXX: The indices of backends need to be kept synchronized between this function and its _backward.
-// TODO: remove cudnn_enabled arg
 std::tuple _batch_norm_impl_index(
     const Tensor& input, const c10::optional& weight_opt /* optional */, const c10::optional& bias_opt /* optional */, const c10::optional& running_mean_opt /* optional */, const c10::optional& running_var_opt /* optional */,
     bool training, double momentum, double eps, bool cudnn_enabled) {
@@ -580,9 +527,24 @@ std::tuple _batch_norm_impl_index(
     check_dims_match_num_input_features("bias", std::move(num_features), bias.sym_numel());
   }
 
-  BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, training, eps);
+  const bool use_cudnn = (
+      input.is_cuda()
+      && input.scalar_type() != at::kBFloat16 && weight.scalar_type() != at::kBFloat16
+      && (input.scalar_type() != at::kHalf
+        || weight.scalar_type() == at::kFloat)
+      && weight.defined() && bias.defined()
+      && ((running_mean.defined() && running_var.defined())
+        || (!running_mean.defined() && !running_var.defined() && training))
+      && (input.dim() >= 3)
+      && ((input.sym_size(0) <= 880801 && training) // spatial, training
+        ||(input.sym_size(0) <= 65535 && !training)) //spatial, eval
+      && detail::getCUDAHooks().compiledWithCuDNN()
+      && eps >= detail::getCUDAHooks().batchnormMinEpsilonCuDNN()
+      && cudnn_enabled && detail::getCUDAHooks().versionCuDNN() >= 5110L
+      && input.sym_numel() < std::numeric_limits::max() // some cuDNN kernels have 32-bit indexing limitations
+      );
 
-  if (backend == BatchNormBackend::Cudnn) {
+  if (use_cudnn) {
     auto input_c = input.contiguous(input.suggest_memory_format());
     auto weight_c = weight.contiguous();
     auto bias_c = bias.contiguous();
@@ -599,7 +561,19 @@ std::tuple _batch_norm_impl_index(
   Tensor reserve = at::empty({0}, input.options().dtype(kByte));
 
-  if (backend == BatchNormBackend::Miopen) {
+  bool use_miopen = (input.is_cuda()
+                && input.dim() <= MIOPEN_DIM_MAX
+                && input.scalar_type() != at::kDouble
+                && input.scalar_type() != at::kBFloat16
+                && (weight.scalar_type() != at::kHalf)
+                && weight.defined() && bias.defined()
+                && ((running_mean.defined() && running_var.defined())
+                  || (!running_mean.defined() && !running_var.defined() && training))
+                && detail::getCUDAHooks().compiledWithMIOpen()
+                && cudnn_enabled
+                );
+
+  if (use_miopen && input.suggest_memory_format() != MemoryFormat::ChannelsLast && input.suggest_memory_format() != MemoryFormat::ChannelsLast3d) {
     return std::tuple_cat(
         at::miopen_batch_norm(
             input.contiguous(), weight.contiguous(), bias.contiguous(),
@@ -663,7 +637,6 @@ std::tuple _batch_norm_impl_index_backward(
   TORCH_INTERNAL_ASSERT(false, "Unsupported impl_index in _batch_norm_impl_index_backward: ", impl_index);
 }
 
-// TODO: remove cudnn_enabled arg
 Tensor batch_norm(
     const Tensor& input, const c10::optional& weight_opt, const c10::optional& bias_opt,
     const c10::optional& running_mean_opt, const c10::optional& running_var_opt,
@@ -674,30 +647,6 @@ Tensor batch_norm(
   const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
   return std::get<0>(at::_batch_norm_impl_index(input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled));
-  // TODO: switch to the new stack after the 2 week FC window
-  // if (training) {
-  //   BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, training, eps);
-  //   if (backend == BatchNormBackend::Cudnn || backend == BatchNormBackend::Miopen) {
-  //     auto input_c = input;
-  //     if (backend == BatchNormBackend::Cudnn) {
-  //       input_c =
input.contiguous(input.suggest_memory_format()); - // } else { - // input_c = input.contiguous(); - // } - // auto weight_c = weight.contiguous(); - // auto bias_c = bias.contiguous(); - // auto rmean_c = running_mean.defined() ? running_mean.contiguous() : running_mean; - // auto rvar_c = running_var.defined() ? running_var.contiguous() : running_var; - // return std::get<0>(at::_batch_norm_with_update(input_c, weight_c, bias_c, const_cast(rmean_c), - // const_cast(rvar_c), momentum, eps)); - // } else { - // return std::get<0>(at::_batch_norm_with_update(input, weight, bias, const_cast(running_mean), - // const_cast(running_var), momentum, eps)); - // } - // } else { - // return std::get<0>(at::_batch_norm_no_update(input, weight, bias, running_mean, running_var, - // momentum, eps)); - // } } Tensor instance_norm( @@ -849,38 +798,6 @@ std::tuple batch_norm_cpu(const Tensor& self, const c10: return batch_norm_cpu_out(self, weight_opt, bias_opt, running_mean_opt, running_var_opt, train, momentum, eps, output, save_mean, save_var); } -std::tuple _batch_norm_with_update_cpu( - const Tensor& input, const c10::optional& weight_opt, const c10::optional& bias_opt, - Tensor& running_mean, Tensor& running_var, double momentum, double eps) { - Tensor output, save_mean, save_var; - std::tie(output, save_mean, save_var) = - batch_norm_cpu(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps); - Tensor reserve = at::empty({0}, input.options().dtype(kByte)); - return std::tuple(output, save_mean, save_var, reserve); -} - -std::tuple _batch_norm_with_update_cpu_out( - const Tensor& input, const c10::optional& weight_opt, const c10::optional& bias_opt, - Tensor& running_mean, Tensor& running_var, double momentum, double eps, - Tensor& out, Tensor& save_mean, Tensor& save_var, Tensor& reserve) { - std::tie(out, save_mean, save_var) = - batch_norm_cpu_out(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps, out, save_mean, save_var); - return std::tuple(out, save_mean, save_var, reserve); -} - - -std::tuple _batch_norm_no_update( - const Tensor& input, const c10::optional& weight_opt, const c10::optional& bias_opt, - const c10::optional& running_mean_opt, const c10::optional& running_var_opt, - double momentum, double eps) { - const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();}); - const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();}); - Tensor output, save_mean, save_var; - std::tie(output, save_mean, save_var) = - batch_norm_cpu(input, weight_opt, bias_opt, const_cast(running_mean), const_cast(running_var), /*update*/false, momentum, eps); - Tensor reserve = at::empty({0}, input.options().dtype(kByte)); - return std::tuple(output, save_mean, save_var, reserve); -} std::tuple _batch_norm_legit_cpu( const Tensor& self, const c10::optional& weight_opt, const c10::optional& bias_opt, @@ -909,13 +826,6 @@ std::tuple _batch_norm_legit_no_stats_cpu_out(const T return batch_norm_cpu_out(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, eps, out, save_mean, save_var); } -std::tuple _new_batch_norm_backward_cpu( - const Tensor& grad_output, const Tensor& input, const Tensor& weight, - const c10::optional& running_mean_opt, const c10::optional& running_var_opt, - const c10::optional& save_mean_opt, const c10::optional& save_var_opt, - bool update, double eps, std::array grad_input_mask, const Tensor& reserve) { - return batch_norm_backward_cpu(grad_output, input, 
weight, running_mean_opt, running_var_opt, save_mean_opt, save_var_opt, update, eps, grad_input_mask); -} std::tuple batch_norm_backward_cpu(const Tensor& grad_out, const Tensor& self, const c10::optional& weight_opt, const c10::optional& running_mean_opt, const c10::optional& running_var_opt, const c10::optional& save_mean_opt, const c10::optional& save_invstd_opt, bool train, double eps, std::array grad_input_mask) { diff --git a/aten/src/ATen/native/Normalization.h b/aten/src/ATen/native/Normalization.h index 1ba99e77b65c..6cd4dcde3705 100644 --- a/aten/src/ATen/native/Normalization.h +++ b/aten/src/ATen/native/Normalization.h @@ -8,12 +8,4 @@ namespace at::native { using renorm_scale_factor_fn = void (*) (TensorIteratorBase& iter, double maxnorm); DECLARE_DISPATCH(renorm_scale_factor_fn, renorm_scale_factor_stub); -enum class BatchNormBackend { - Native, - Cudnn, - Miopen, -}; - -TORCH_API BatchNormBackend _select_batch_norm_backend(const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean, const Tensor& running_var, bool training, double eps); - } // namespace at::native diff --git a/aten/src/ATen/native/cuda/Normalization.cu b/aten/src/ATen/native/cuda/Normalization.cu index 1b4c1593e539..655d32b10a50 100644 --- a/aten/src/ATen/native/cuda/Normalization.cu +++ b/aten/src/ATen/native/cuda/Normalization.cu @@ -1,7 +1,5 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include -#include -#include #include #include #include @@ -14,8 +12,6 @@ #include #include #else -#include -#include #include #include #include @@ -23,12 +19,8 @@ #include #include #include -#include -#include #include #include -#include -#include #include #include #include @@ -481,54 +473,6 @@ std::tuple batch_norm_cuda(const Tensor& self, const c10 return std::make_tuple(output, save_mean, save_invstd); } -std::tuple _batch_norm_with_update_cuda( - const Tensor& input, const c10::optional& weight_opt, const c10::optional& bias_opt, - Tensor& running_mean, Tensor& running_var, double momentum, double eps) { - // See [Note: hacky wrapper removal for optional tensor] - c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); - const Tensor& weight = *weight_maybe_owned; - const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();}); - Tensor output, save_mean, save_var, reserve; - - BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, /*training*/true, eps); - if (backend == BatchNormBackend::Cudnn) { - std::tie(output, save_mean, save_var, reserve) = - at::cudnn_batch_norm(input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps); - } else if (backend == BatchNormBackend::Miopen) { - reserve = at::empty({0}, input.options().dtype(kByte)); - std::tie(output, save_mean, save_var) = - at::miopen_batch_norm(input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps); - } else { - reserve = at::empty({0}, input.options().dtype(kByte)); - std::tie(output, save_mean, save_var) = - batch_norm_cuda(input, weight_opt, bias_opt, running_mean, running_var, /*training*/true, momentum, eps); - } - return std::tuple(output, save_mean, save_var, reserve); -} - -std::tuple _batch_norm_with_update_cuda_out( - const Tensor& input, const c10::optional& weight_opt, const c10::optional& bias_opt, - Tensor& running_mean, Tensor& running_var, double momentum, double eps, - Tensor& out, Tensor& save_mean, Tensor& save_var, Tensor& reserve) { - // See [Note: hacky wrapper removal for 
optional tensor] - c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); - const Tensor& weight = *weight_maybe_owned; - const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();}); - - BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, /*training*/true, eps); - if (backend == BatchNormBackend::Cudnn) { - std::tie(out, save_mean, save_var, reserve) = - at::cudnn_batch_norm_out(out, save_mean, save_var, reserve, input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps); - } else if (backend == BatchNormBackend::Miopen) { - std::tie(out, save_mean, save_var) = - at::miopen_batch_norm_out(out, save_mean, save_var, input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps); - } else { - std::tie(out, save_mean, save_var) = - batch_norm_cuda_out(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps, out, save_mean, save_var); - } - return std::tuple(out, save_mean, save_var, reserve); -} - std::tuple _batch_norm_legit_cuda(const Tensor& self, const c10::optional& weight_opt, const c10::optional& bias_opt, Tensor& running_mean, Tensor& running_var, bool train, double momentum, double epsilon) { return batch_norm_cuda(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, epsilon); } @@ -545,28 +489,6 @@ std::tuple _batch_norm_legit_no_stats_cuda_out(const return batch_norm_cuda_out(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, epsilon, output, save_mean, save_invstd); } -std::tuple _new_batch_norm_backward_cuda( - const Tensor& grad_output, const Tensor& input, const Tensor& weight, - const c10::optional& running_mean_opt, const c10::optional& running_var_opt, - const c10::optional& save_mean_opt, const c10::optional& save_var_opt, - bool update, double eps, std::array grad_input_mask, const Tensor& reserve) { - const Tensor& dummy_bias = at::empty(1); - const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();}); - const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();}); - const Tensor& save_mean = c10::value_or_else(save_mean_opt, [] {return Tensor();}); - const Tensor& save_var = c10::value_or_else(save_var_opt, [] {return Tensor();}); - - BatchNormBackend backend = _select_batch_norm_backend(input, weight, dummy_bias, running_mean, running_var, /*training*/true, eps); - - if (backend == BatchNormBackend::Cudnn) { - return at::cudnn_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var, eps, reserve); - } else if (backend == BatchNormBackend::Miopen) { - return at::miopen_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var, eps); - } else { - return batch_norm_backward_cuda(grad_output, input, weight, running_mean, running_var, save_mean, save_var, update, eps, grad_input_mask); - } -} - std::tuple batch_norm_backward_cuda(const Tensor& grad_out, const Tensor& input, const c10::optional& weight_opt, const c10::optional& running_mean_opt, const c10::optional& running_var_opt, const c10::optional& save_mean_opt, const c10::optional& save_invstd_opt, bool train, double epsilon, std::array grad_input_mask) { // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned weight = at::borrow_from_optional_tensor(weight_opt); diff --git a/aten/src/ATen/native/cudnn/BatchNorm.cpp b/aten/src/ATen/native/cudnn/BatchNorm.cpp index 
af4fae44b531..eb21f8fb5a73 100644 --- a/aten/src/ATen/native/cudnn/BatchNorm.cpp +++ b/aten/src/ATen/native/cudnn/BatchNorm.cpp @@ -2,7 +2,6 @@ #include #include #include -#include #if !AT_CUDNN_ENABLED() @@ -36,24 +35,18 @@ std::tuple cudnn_batch_norm_backward( AT_ERROR("cudnn_batch_norm_backward: ATen not compiled with cuDNN support"); } -size_t _get_cudnn_batch_norm_reserve_space_size( - const Tensor& input_t, - bool training) { - AT_ERROR( - "_get_cudnn_batch_norm_reserve_space_size: ATen not compiled with cuDNN support"); -} - } // namespace native } // namespace at #else // AT_CUDNN_ENABLED -#include #include #include #include #include +#include + #ifndef AT_PER_OPERATOR_HEADERS #include #include @@ -98,21 +91,6 @@ cudnnBatchNormMode_t getCudnnBatchNormMode( } // namespace -size_t _get_cudnn_batch_norm_reserve_space_size( - const Tensor& input_t, - bool training) { - size_t reserve_size; - TensorArg input{input_t, "input", 1}; - TensorDescriptor idesc{*input, 4}; - auto handle = getCudnnHandle(); - cudnnBatchNormMode_t mode = getCudnnBatchNormMode( - training, input->suggest_memory_format(), input->dim()); - auto op = CUDNN_BATCHNORM_OPS_BN; - AT_CUDNN_CHECK(cudnnGetBatchNormalizationTrainingExReserveSpaceSize( - handle, mode, op, nullptr, idesc.desc(), &reserve_size)); - return reserve_size; -} - std::tuple cudnn_batch_norm( const Tensor& input_t, const Tensor& weight_t, @@ -201,8 +179,9 @@ std::tuple cudnn_batch_norm( Tensor workspace = at::empty(workspace_size, input->options().dtype(kByte)); // get the reserved size and allocate as tensor - size_t reserve_size = - _get_cudnn_batch_norm_reserve_space_size(input_t, true /* training */); + size_t reserve_size; + AT_CUDNN_CHECK(cudnnGetBatchNormalizationTrainingExReserveSpaceSize( + handle, mode, op, nullptr, idesc.desc(), &reserve_size)); reserve = at::empty(reserve_size, input->options().dtype(kByte)); AT_CUDNN_CHECK(cudnnBatchNormalizationForwardTrainingEx( diff --git a/aten/src/ATen/native/cudnn/BatchNorm.h b/aten/src/ATen/native/cudnn/BatchNorm.h deleted file mode 100644 index 3da76c0c16e4..000000000000 --- a/aten/src/ATen/native/cudnn/BatchNorm.h +++ /dev/null @@ -1,6 +0,0 @@ -namespace at::native { - -TORCH_API size_t -_get_cudnn_batch_norm_reserve_space_size(const Tensor& input_t, bool training); - -} // namespace at::native diff --git a/aten/src/ATen/native/mkldnn/Normalization.cpp b/aten/src/ATen/native/mkldnn/Normalization.cpp index 0aced614a0ea..108ce354ec9b 100644 --- a/aten/src/ATen/native/mkldnn/Normalization.cpp +++ b/aten/src/ATen/native/mkldnn/Normalization.cpp @@ -6,8 +6,6 @@ #ifndef AT_PER_OPERATOR_HEADERS #include #else -#include -#include #include #include #include @@ -61,20 +59,6 @@ std::tuple _mkldnn_batch_norm_legit_no_stats( TORCH_CHECK(false, "_mkldnn_batch_norm_legit_no_stats: ATen not compiled with MKLDNN support"); } -std::tuple _batch_norm_with_update_mkldnn( - const Tensor& input, const c10::optional& weight_opt, const c10::optional& bias_opt, - Tensor& running_mean, Tensor& running_var, double momentum, double eps) { - TORCH_CHECK(false, "_batch_norm_with_update_mkldnn: ATen not compiled with MKLDNN support"); -} - -std::tuple _new_batch_norm_backward_mkldnn( - const Tensor& grad_output, const Tensor& input, const Tensor& weight, - const c10::optional& running_mean_opt, const c10::optional& running_var_opt, - const c10::optional& save_mean_opt, const c10::optional& save_var_opt, - bool update, double eps, std::array grad_input_mask, const Tensor& reserve) { - TORCH_CHECK(false, 
"_new_batch_norm_backward_mkldnn: ATen not compiled with MKLDNN support"); -} - } // namespace native } // namespace at @@ -208,17 +192,6 @@ std::tuple mkldnn_batch_norm( } -std::tuple _batch_norm_with_update_mkldnn( - const Tensor& input, const c10::optional& weight_opt, const c10::optional& bias_opt, - Tensor& running_mean, Tensor& running_var, double momentum, double eps) { - Tensor output, save_mean, save_var; - std::tie(output, save_mean, save_var) = - mkldnn_batch_norm(input, weight_opt, bias_opt, running_mean, running_var, /*train*/true, momentum, eps); - Tensor reserve = empty_mkldnn({0}, input.scalar_type()); - return std::tuple(output, save_mean, save_var, reserve); -} - - std::tuple _mkldnn_batch_norm_legit( const Tensor& input, const c10::optional& weight_opt, const c10::optional& bias_opt, Tensor& running_mean, Tensor& running_var, bool train, @@ -237,15 +210,6 @@ std::tuple _mkldnn_batch_norm_legit_no_stats( } -std::tuple _new_batch_norm_backward_mkldnn( - const Tensor& grad_output, const Tensor& input, const Tensor& weight, - const c10::optional& running_mean_opt, const c10::optional& running_var_opt, - const c10::optional& save_mean_opt, const c10::optional& save_var_opt, - bool update, double eps, std::array grad_input_mask, const Tensor& reserve) { - return mkldnn_batch_norm_backward(grad_output, input, weight, running_mean_opt, running_var_opt, save_mean_opt, save_var_opt, update, eps, grad_input_mask); -} - - std::tuple mkldnn_batch_norm_backward(const Tensor& grad_output, const Tensor& input, const c10::optional& weight_opt, const c10::optional& running_mean_opt, const c10::optional& running_var_opt, const c10::optional& save_mean_opt, const c10::optional& save_invstd_opt, bool train, diff --git a/aten/src/ATen/native/mps/operations/Normalization.mm b/aten/src/ATen/native/mps/operations/Normalization.mm index bdca3b09780b..eb754ae59768 100644 --- a/aten/src/ATen/native/mps/operations/Normalization.mm +++ b/aten/src/ATen/native/mps/operations/Normalization.mm @@ -10,9 +10,7 @@ #include #include #else -#include #include -#include #include #include #include @@ -408,36 +406,6 @@ std::tuple batch_norm_mps(const Tensor& self, return std::make_tuple(output, save_mean, save_var); } -std::tuple _batch_norm_with_update_mps(const Tensor& input, - const c10::optional& weight_opt, - const c10::optional& bias_opt, - Tensor& running_mean, - Tensor& running_var, - double momentum, - double eps) { - Tensor output, save_mean, save_var; - std::tie(output, save_mean, save_var) = - batch_norm_mps(input, weight_opt, bias_opt, running_mean, running_var, /*train*/ true, momentum, eps); - Tensor reserve = at::empty({0}, input.options().dtype(kByte)); - return std::tuple(output, save_mean, save_var, reserve); -} - -std::tuple _batch_norm_with_update_mps_out(const Tensor& input, - const c10::optional& weight_opt, - const c10::optional& bias_opt, - Tensor& running_mean, - Tensor& running_var, - double momentum, - double eps, - Tensor& out, - Tensor& save_mean, - Tensor& save_var, - Tensor& reserve) { - std::tie(out, save_mean, save_var) = batch_norm_mps_out( - input, weight_opt, bias_opt, running_mean, running_var, /*update*/ true, momentum, eps, out, save_mean, save_var); - return std::tuple(out, save_mean, save_var, reserve); -} - std::tuple _batch_norm_legit_mps(const Tensor& self, const c10::optional& weight_opt, const c10::optional& bias_opt, @@ -503,29 +471,6 @@ static string get_mem_string(c10::MemoryFormat memory_format) { } // Batch norm backward -std::tuple 
_new_batch_norm_backward_mps(const Tensor& grad_output, - const Tensor& input, - const Tensor& weight, - const c10::optional& running_mean_opt, - const c10::optional& running_var_opt, - const c10::optional& save_mean_opt, - const c10::optional& save_var_opt, - bool update, - double eps, - std::array grad_input_mask, - const Tensor& reserve) { - return batch_norm_backward_mps(grad_output, - input, - weight, - running_mean_opt, - running_var_opt, - save_mean_opt, - save_var_opt, - update, - eps, - grad_input_mask); -} - std::tuple batch_norm_backward_mps(const Tensor& grad_out, const Tensor& input, const c10::optional& weight_opt, diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 5d90f64e4e11..5216d80625f4 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -6496,32 +6496,6 @@ SparseCPU, SparseCUDA: norm_sparse autogen: native_norm.ScalarOpt_dim_dtype_out -- func: _batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor) - dispatch: - CPU: _batch_norm_with_update_cpu - CUDA: _batch_norm_with_update_cuda - MPS: _batch_norm_with_update_mps - MkldnnCPU: _batch_norm_with_update_mkldnn - autogen: _batch_norm_with_update_functional - -- func: _batch_norm_with_update.out(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps, *, Tensor(d!) out, Tensor(e!) save_mean, Tensor(f!) save_invstd, Tensor(g!) reserve) -> (Tensor(d!), Tensor(e!), Tensor(f!), Tensor(g!)) - dispatch: - CPU: _batch_norm_with_update_cpu_out - CUDA: _batch_norm_with_update_cuda_out - MPS: _batch_norm_with_update_mps_out - -- func: _batch_norm_no_update(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor) - dispatch: - CompositeExplicitAutograd: _batch_norm_no_update - autogen: _batch_norm_no_update.out - -- func: batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? 
save_var, bool update, float eps, bool[3] output_mask, Tensor reserve) -> (Tensor, Tensor, Tensor) - dispatch: - CPU: _new_batch_norm_backward_cpu - CUDA: _new_batch_norm_backward_cuda - MPS: _new_batch_norm_backward_mps - MkldnnCPU: _new_batch_norm_backward_mkldnn - # TODO: reduce signatures down to one when optional args is available - func: _sparse_sum(Tensor self) -> Tensor diff --git a/test/distributed/_tensor/test_dtensor_ops.py b/test/distributed/_tensor/test_dtensor_ops.py index afc6c17e4b6e..bc7ec6dea89e 100644 --- a/test/distributed/_tensor/test_dtensor_ops.py +++ b/test/distributed/_tensor/test_dtensor_ops.py @@ -114,7 +114,6 @@ dtensor_fails = { xfail("as_strided", "partial_views"), xfail("as_strided_scatter"), xfail("bernoulli"), - xfail("_batch_norm_with_update"), xfail("block_diag"), xfail("broadcast_shapes"), xfail("cauchy"), diff --git a/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_exhaustive_batch_norm_with_update_cpu_float32 b/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_exhaustive_batch_norm_with_update_cpu_float32 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_fake_exhaustive_batch_norm_with_update_cpu_float32 b/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_fake_exhaustive_batch_norm_with_update_cpu_float32 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_batch_norm_with_update_cpu_float32 b/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_batch_norm_with_update_cpu_float32 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_inplace_batch_norm_with_update_cpu_float32 b/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_inplace_batch_norm_with_update_cpu_float32 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_out_batch_norm_with_update_cpu_float32 b/test/dynamo_skips/TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_out_batch_norm_with_update_cpu_float32 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/expect/HasDecompTest.test_aten_core_operators.expect b/test/expect/HasDecompTest.test_aten_core_operators.expect index f23021e86681..c9f379e9a4d7 100644 --- a/test/expect/HasDecompTest.test_aten_core_operators.expect +++ b/test/expect/HasDecompTest.test_aten_core_operators.expect @@ -6,9 +6,6 @@ aten::_adaptive_avg_pool2d aten::_adaptive_avg_pool2d.out aten::_addmm_activation aten::_addmm_activation.out -aten::_batch_norm_no_update -aten::_batch_norm_with_update -aten::_batch_norm_with_update_functional aten::_euclidean_dist.out aten::_fused_dropout aten::_fused_dropout.out @@ -79,7 +76,6 @@ aten::atanh aten::atanh.out aten::atanh_ aten::baddbmm_ -aten::batch_norm_backward aten::bitwise_and.Scalar aten::bitwise_and.Scalar_Tensor aten::bitwise_and.Scalar_Tensor_out diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 181029bd73ef..d298f651d60b 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -30,8 +30,6 @@ aten::_amp_update_scale.out aten::_amp_update_scale_ aten::_assert_async aten::_assert_async.msg 
-aten::_batch_norm_no_update.out -aten::_batch_norm_with_update.out aten::_cdist_backward aten::_cdist_backward.out aten::_cdist_forward diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index a10acf5badd8..15d5ff0627eb 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -11,7 +11,7 @@ import unittest from torch.testing._internal.common_utils import unMarkDynamoStrictTest from torch.testing._internal.common_utils import TestCase, run_tests, is_iterable_of_tensors, IS_MACOS, \ - IS_X86, parametrize, TEST_WITH_ASAN, TEST_WITH_ROCM, noncontiguous_like + IS_X86, parametrize, TEST_WITH_ASAN, noncontiguous_like from torch.testing._internal.common_utils import skipIfRocm, runOnRocm import torch from torch import Tensor @@ -368,10 +368,6 @@ aliasing_ops_list_return = { # 'tensor_split' not composite compliant, see vjp_fail } -skip_noncontig = { - '_batch_norm_with_update', -} - @unittest.skipIf(TEST_WITH_ASAN, "tests time out with asan, are probably redundant") @unMarkDynamoStrictTest @@ -400,14 +396,6 @@ class TestOperators(TestCase): xfail('nn.functional.scaled_dot_product_attention'), xfail("torch.ops.aten._flash_attention_forward"), xfail("torch.ops.aten._efficient_attention_forward"), - - # RuntimeError: Expected contiguous tensor, but got - # non-contiguous tensor for argument #2 'grad_output' - decorate( - '_batch_norm_with_update', - decorator=expectedFailureIf(TEST_WITH_ROCM), - device_type='cuda', - ) })) @opsToleranceOverride('TestOperators', 'test_grad', ( tol1('nn.functional.binary_cross_entropy_with_logits', @@ -445,10 +433,9 @@ class TestOperators(TestCase): args = [sample.input] + list(sample.args) kwargs = sample.kwargs - if op.name not in skip_noncontig: - noncontig_sample = sample.noncontiguous() - noncontig_args = [noncontig_sample.input] + list(noncontig_sample.args) - noncontig_kwargs = noncontig_sample.kwargs + noncontig_sample = sample.noncontiguous() + noncontig_args = [noncontig_sample.input] + list(noncontig_sample.args) + noncontig_kwargs = noncontig_sample.kwargs diff_argnums = tuple(i for i, arg in enumerate(args) if diff_arg(arg)) assert len(diff_argnums) > 0 @@ -471,12 +458,11 @@ class TestOperators(TestCase): return result result = grad(wrapped_fn, diff_argnums)(*args, **kwargs) + result_noncontig = grad(wrapped_fn, diff_argnums)(*noncontig_args, **noncontig_kwargs) expected = _autograd_grad(_as_tuple(wrapped_fn(*args, **kwargs)), diff_args) - self.assertEqual(result, expected) - if op.name not in skip_noncontig: - result_noncontig = grad(wrapped_fn, diff_argnums)(*noncontig_args, **noncontig_kwargs) - self.assertEqual(result_noncontig, expected) + self.assertEqual(result, expected) + self.assertEqual(result_noncontig, expected) @with_tf32_off # https://github.com/pytorch/pytorch/issues/86798 @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,)) @@ -490,8 +476,7 @@ class TestOperators(TestCase): skip('nn.functional.max_unpool2d'), # fails everywhere except on windows skip('nn.functional.max_unpool3d'), # fails everywhere except on mac xfail("native_batch_norm"), # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents - xfail("_native_batch_norm_legit"), # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents - xfail("_batch_norm_with_update"), # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents + xfail("_native_batch_norm_legit"), # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents 
xfail('nn.functional.scaled_dot_product_attention'), xfail('torch.ops.aten._flash_attention_forward'), @@ -560,17 +545,15 @@ class TestOperators(TestCase): self.jvp_opinfo_test(outplace_variant, sample, sample.output_process_fn_grad, clone_inputs=False, - fixme_ref_jvp_local=fixme_ref_jvp_local, - test_noncontig=op.name not in skip_noncontig) + fixme_ref_jvp_local=fixme_ref_jvp_local) if is_valid_inplace_sample_input(sample, op, inplace_variant): self.jvp_opinfo_test(inplace_variant, sample, sample.output_process_fn_grad, clone_inputs=True, - fixme_ref_jvp_local=fixme_ref_jvp_local, - test_noncontig=op.name not in skip_noncontig) + fixme_ref_jvp_local=fixme_ref_jvp_local) def jvp_opinfo_test(self, fn, sample, output_process_fn, - clone_inputs, fixme_ref_jvp_local, test_noncontig): + clone_inputs, fixme_ref_jvp_local): # NB: we used requires_grad=True to determine where the primals are, # but don't need that information otherwise args = (sample.input,) + sample.args @@ -580,6 +563,15 @@ class TestOperators(TestCase): orig_primals = tree_map(lambda x: x.detach(), primals) orig_tangents = tree_map(lambda x: torch.randn_like(x), primals) + noncontig_sample = sample.noncontiguous() + noncontig_args = (noncontig_sample.input,) + noncontig_sample.args + noncontig_kwargs = sample.kwargs + noncontig_fn, primals = normalize_op_input_output2( + fn, noncontig_args, noncontig_kwargs, + output_process_fn, requires_grad=True) + noncontig_primals = tree_map(lambda x: x.detach(), primals) + noncontig_tangents = tree_map(lambda x: noncontiguous_like(x), orig_tangents) + def maybe_clone_inputs(): if clone_inputs: primals = tree_map(torch.clone, orig_primals) @@ -594,24 +586,15 @@ class TestOperators(TestCase): primals, tangents = maybe_clone_inputs() primal_outs, tangent_outs = jvp(contig_fn, primals, tangents) + noncontig_primal_outs, noncontig_tangent_outs = jvp(noncontig_fn, + noncontig_primals, + noncontig_tangents) + self.assertEqual(primal_outs, expected_primal_outs) self.assertEqual(tangent_outs, expected_tangent_outs) - if test_noncontig: - noncontig_sample = sample.noncontiguous() - noncontig_args = (noncontig_sample.input,) + noncontig_sample.args - noncontig_kwargs = sample.kwargs - noncontig_fn, primals = normalize_op_input_output2( - fn, noncontig_args, noncontig_kwargs, - output_process_fn, requires_grad=True) - noncontig_primals = tree_map(lambda x: x.detach(), primals) - noncontig_tangents = tree_map(lambda x: noncontiguous_like(x), orig_tangents) - noncontig_primal_outs, noncontig_tangent_outs = jvp(noncontig_fn, - noncontig_primals, - noncontig_tangents) - - self.assertEqual(noncontig_primal_outs, expected_primal_outs) - self.assertEqual(noncontig_tangent_outs, expected_tangent_outs) + self.assertEqual(noncontig_primal_outs, expected_primal_outs) + self.assertEqual(noncontig_tangent_outs, expected_tangent_outs) @with_tf32_off # https://github.com/pytorch/pytorch/issues/86798 @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,)) @@ -672,22 +655,22 @@ class TestOperators(TestCase): result = fn(*primals) cotangents = tree_map(lambda x: torch.randn_like(x), result) + noncontig_fn, noncontig_primals = normalize_op_input_output(_op, sample.noncontiguous()) + noncontig_cotangents = tree_map(lambda x: noncontiguous_like(x), cotangents) + out, vjp_fn = vjp(fn, *primals) self.assertEqual(out, result) result_vjps = vjp_fn(cotangents) + out_noncontig, vjp_fn = vjp(noncontig_fn, *noncontig_primals) + self.assertEqual(out_noncontig, result) + noncontig_result_vjps = 
vjp_fn(noncontig_cotangents) + _, vjp_fn = ref_vjp(fn, *primals) expected_vjps = vjp_fn(cotangents) self.assertEqual(result_vjps, expected_vjps) - - if op.name not in skip_noncontig: - noncontig_fn, noncontig_primals = normalize_op_input_output(_op, sample.noncontiguous()) - noncontig_cotangents = tree_map(lambda x: noncontiguous_like(x), cotangents) - out_noncontig, vjp_fn = vjp(noncontig_fn, *noncontig_primals) - self.assertEqual(out_noncontig, result) - noncontig_result_vjps = vjp_fn(noncontig_cotangents) - self.assertEqual(noncontig_result_vjps, expected_vjps) + self.assertEqual(noncontig_result_vjps, expected_vjps) _test(op) for a_op in op.aliases: @@ -847,8 +830,6 @@ class TestOperators(TestCase): xfail("to_sparse"), xfail("native_batch_norm"), xfail("_native_batch_norm_legit"), - # TODO: implement batching rule - xfail("_batch_norm_with_update"), })) @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,)) @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)}) @@ -940,8 +921,6 @@ class TestOperators(TestCase): skip('linalg.svdvals'), # # really annoying thing where it passes correctness check but not has_batch_rule skip("native_batch_norm"), skip("_native_batch_norm_legit"), - # TODO: implement batching rule - skip("_batch_norm_with_update"), xfail('__getitem__', ''), # dynamic error xfail('nanquantile', device_type='cpu'), # checks q via a .item() call xfail('nn.functional.gaussian_nll_loss'), # checks var for if any value < 0 @@ -1066,8 +1045,6 @@ class TestOperators(TestCase): xfail('nn.functional.batch_norm', 'without_cudnn'), xfail("native_batch_norm"), xfail("_native_batch_norm_legit"), - # TODO: implement batching rule - xfail("_batch_norm_with_update"), # https://github.com/pytorch/pytorch/issues/96560 # ROCm: NotImplementedError @@ -1253,8 +1230,6 @@ class TestOperators(TestCase): xfail('sparse.mm', 'reduce'), xfail("native_batch_norm"), xfail("_native_batch_norm_legit"), - # TODO: implement batching rule - xfail("_batch_norm_with_update"), xfail("native_dropout_backward"), xfail("index_fill"), # aten::_unique hit the vmap fallback which is currently disabled })) @@ -1331,8 +1306,6 @@ class TestOperators(TestCase): xfail('sparse.mm', 'reduce'), xfail("native_batch_norm"), xfail("_native_batch_norm_legit"), - # TODO: implement batching rule - xfail("_batch_norm_with_update"), xfail('as_strided', 'partial_views'), })) def test_vjpvmap(self, device, dtype, op): @@ -1591,8 +1564,6 @@ class TestOperators(TestCase): # place, were not batched. 
xfail("native_batch_norm"), xfail("_native_batch_norm_legit"), - # TODO: implement batching rule - xfail("_batch_norm_with_update"), xfail('native_dropout_backward'), })) @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,)) @@ -1837,14 +1808,6 @@ class TestOperators(TestCase): skip('sparse.sampled_addmm', ''), skip('sparse.mm', 'reduce'), skip('native_layer_norm', '', device_type='cpu'), - - # RuntimeError: Expected contiguous tensor, but got - # non-contiguous tensor for argument #2 'grad_output' - decorate( - '_batch_norm_with_update', - decorator=expectedFailureIf(TEST_WITH_ROCM), - device_type='cuda', - ) }) @opsToleranceOverride('TestOperators', 'test_vmap_autograd_grad', ( tol1('linalg.householder_product', diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index 1b5723f3fc2c..df46c9513a4f 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -3625,8 +3625,6 @@ class TestVmapOperatorsOpInfo(TestCase): # which will be updated in place, were not batched. xfail('native_batch_norm'), xfail('_native_batch_norm_legit'), - # TODO: implement batching rule - xfail('_batch_norm_with_update'), xfail('tril'), # Exception not raised on error input xfail('triu'), # Exception not raised on error input xfail('as_strided', 'partial_views'), @@ -3666,8 +3664,6 @@ class TestVmapOperatorsOpInfo(TestCase): # which will be updated in place, were not batched. xfail('native_batch_norm'), xfail('_native_batch_norm_legit'), - # TODO: implement batching rule - xfail('_batch_norm_with_update'), xfail('histogram'), xfail('scatter_reduce', 'sum'), xfail('scatter_reduce', 'mean'), diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index c51ec4bf4288..8a35bc6c11a6 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -192,7 +192,6 @@ inductor_skips["cuda"] = { "nn.functional.cosine_embedding_loss": {b8}, "native_batch_norm": {f16, f32, f64}, "_native_batch_norm_legit": {f16, f32, f64}, - "_batch_norm_with_update": {f16, f32, f64}, } if not SM80OrLater: diff --git a/test/onnx/test_fx_op_consistency.py b/test/onnx/test_fx_op_consistency.py index b6a417644be4..4d38e7081d5c 100644 --- a/test/onnx/test_fx_op_consistency.py +++ b/test/onnx/test_fx_op_consistency.py @@ -157,11 +157,6 @@ EXPECTED_SKIPS_OR_FAILS: Tuple[onnx_test_common.DecorateMeta, ...] = ( dtypes=(torch.float16,), reason="fixme: Assertion error: result mismatch and type error", ), - skip( - "_batch_norm_with_update", - dtypes=(torch.float16,), - reason="fixme: Assertion error: result mismatch and type error", - ), xfail( "_softmax_backward_data", reason=onnx_test_common.reason_dynamo_does_not_support("assert all(isinstance(a, KNOWN_TYPES) for a in flat_args)") @@ -1357,20 +1352,6 @@ SKIP_XFAIL_SUBTESTS: tuple[onnx_test_common.DecorateMeta, ...] = ( model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, reason="https://github.com/pytorch/pytorch/issues/115106", ), - skip( - "_batch_norm_with_update", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - reason="https://github.com/pytorch/pytorch/issues/115106", - ), - # TODO: This test currently fails only for certain inputs, e.g. shape([3, 1]). - # Numerically the ONNX program is correct, but the output shapes for `save_mean` - # and `save_var` were tensor(-2.1268) instead of the correct tensor([-2.1268]) - # for example. 
- skip( - "_batch_norm_with_update", - model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE, - reason="not supported yet", - ), xfail( "addmm", # xfail can't only use dtypes to catch all cases matcher=lambda sample: sample.input.dtype diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index 071249192ec6..635e602adab6 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -22,7 +22,7 @@ torch._C._get_graph_executor_optimize(True) from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH_EXECUTOR, \ enable_profiling_mode_for_profiling_tests, slowTest, skipIfTorchDynamo, TEST_WITH_ASAN, \ - TEST_WITH_ROCM, IS_FBCODE + IS_FBCODE from torch.testing._internal.jit_utils import JitTestCase, \ RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, warmup_backward, set_fusion_group_inlining, \ clone_inputs, get_traced_sample_variant_pairs, TensorExprTestOptions, NoTracerWarnContextManager @@ -2202,7 +2202,6 @@ class TestTEFuser(JitTestCase): @skipIfTorchDynamo("too slow") @unittest.skipIf(TEST_WITH_ASAN, "takes 10+ minutes on asan") - @unittest.skipIf(TEST_WITH_ROCM, "Tensor-likes are not close for nans") def test_batch_norm(self): def test(fn, args): trace = torch.jit.trace(fn, args) diff --git a/test/test_meta.py b/test/test_meta.py index 6a092b34d778..65e17cea5fed 100644 --- a/test/test_meta.py +++ b/test/test_meta.py @@ -708,11 +708,8 @@ meta_function_device_expected_failures_only_outplace = defaultdict(dict) meta_function_device_skips = defaultdict(dict) meta_function_device_expected_failures['cpu'] = { - # TODO: The decomps for these batch norm ops return different dtypes depending - # on the device. We should make this work better with meta tensors. torch.native_batch_norm: {bf16, f16}, torch._native_batch_norm_legit: {bf16, f16}, - torch.ops.aten._batch_norm_with_update: {bf16, f16}, torch.native_layer_norm: {bf16, f16}, } @@ -727,11 +724,8 @@ meta_function_device_expected_failures['cuda'] = { } meta_function_device_skips['cpu'] = { - # TODO: The decomps for these batch norm ops return different dtypes depending - # on the device. We should make this work better with meta tensors. torch.native_batch_norm: {f32, f64}, torch._native_batch_norm_legit: {f32, f64}, - torch.ops.aten._batch_norm_with_update: {f32, f64}, } meta_function_device_skips['cuda'] = { @@ -856,13 +850,9 @@ meta_dispatch_device_expected_failures = defaultdict(dict) meta_dispatch_device_skips = defaultdict(dict) meta_dispatch_device_expected_failures['cpu'] = { - # TODO: The decomps for these batch norm ops return different dtypes depending - # on the device. We should make this work better with meta tensors. aten.native_batch_norm.default: {bf16, f16}, aten._native_batch_norm_legit.default: {bf16, f16}, aten._native_batch_norm_legit.no_stats: {bf16, f16}, - aten._batch_norm_with_update.default: {bf16, f16}, - aten.native_layer_norm.default: {bf16, f16}, aten.histc.default: {f16}, aten.histc.out: {f16}, @@ -887,13 +877,9 @@ meta_dispatch_device_expected_failures['cuda'] = { meta_dispatch_device_skips['cpu'] = { aten._embedding_bag_forward_only.default: {bf16, f16, f32, f64}, - - # TODO: The decomps for these batch norm ops return different dtypes depending - # on the device. We should make this work better with meta tensors. 
aten.native_batch_norm.default: {f32, f64}, aten._native_batch_norm_legit.default: {f32, f64}, aten._native_batch_norm_legit.no_stats: {f32, f64}, - aten._batch_norm_with_update.default: {f32, f64}, # If the computation dtype is different from the input # dtype this will fail. CPU execution may also have a diff --git a/test/test_mps.py b/test/test_mps.py index 63bc604c581a..376bd3ad6203 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -11395,7 +11395,6 @@ class TestConsistency(TestCaseMPS): 'nn.functional.gelu', 'nn.functional.glu', '_native_batch_norm_legit', - '_batch_norm_with_update', 'native_batch_norm', 'softmax', '_softmax_backward_data', diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 08947ba66599..33383449eca1 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1927,15 +1927,6 @@ inplace_symbolic_tensor_failures = { } out_symbolic_tensor_failures = { - # Cast error details: Unable to cast (...) to Tensor - # - # This happens because the test is set up to call the out variant using the `out` kwarg: - # torch._some_op(arg1, arg2, out=(out1, out2, out3)) - # - # However, this only works on torch ops, not aten ops. For `_batch_norm_with_update`, - # this fails because the op has no python bindings, so it doesn't support the `out` kwarg - # way of calling its out variant. - xfail('_batch_norm_with_update', ''), xfail('_native_batch_norm_legit', ''), xfail('angle', ''), xfail('argmax', ''), diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index f85af41ec19e..e692ae9e9234 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -1250,20 +1250,6 @@ self: grad.neg() result: auto_element_wise -- name: _batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor) - input, weight, bias: "grad.defined() ? batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, /*update*/true, eps, grad_input_mask, retain_variables ? result3.clone() : result3) : std::tuple()" - result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, true, eps) - -- name: _batch_norm_no_update(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor) - input, weight, bias: "grad.defined() ? batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, /*update*/false, eps, grad_input_mask, retain_variables ? result3.clone() : result3) : std::tuple()" - result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, false, eps) - -- name: batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? 
save_var, bool update, float eps, bool[3] output_mask, Tensor reserve) -> (Tensor, Tensor, Tensor) - input, weight, grad_out: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, running_mean, running_var, update, eps, save_mean, save_var, grad_input_mask) - save_mean: not_implemented("batch_norm_backward save_mean") - save_var: not_implemented("batch_norm_backward save_var") - reserve: not_implemented("batch_norm_backward reserve") - - name: nextafter(Tensor self, Tensor other) -> Tensor self: not_implemented("nextafter") other: not_implemented("nextafter") diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index be942ca5bfbb..9a689af79021 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -160,12 +160,9 @@ _SKIP_PYTHON_BINDINGS = [ "fill.Tensor", # only used by the functionalization pass "fill.Scalar", # only used by the functionalization pass "lift.*", - "normal_functional", # only used by the functionalization pass + "normal_functional", # only used by the functionalization pas "nbytes", "itemsize", - "_batch_norm_with_update", - "_batch_norm_with_update_out", - "_batch_norm_no_update", ] SKIP_PYTHON_BINDINGS = [ diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 002788131019..516072b531d7 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1132,7 +1132,6 @@ def _meta_in_tls_dispatch_include() -> _bool: ... def _stash_obj_in_tls(key: str, arg: Any) -> None: ... def _get_obj_in_tls(key: str) -> Any: ... def _is_key_in_tls(key: str) -> _bool: ... -def _select_batch_norm_backend(*args, **kwargs) -> BatchNormBackend: ... def _select_conv_backend(*args, **kwargs) -> ConvBackend: ... def _conv_determine_backend_memory_format( input: Tensor, @@ -1198,8 +1197,6 @@ class _LinalgBackend: Cusolver: _LinalgBackend Magma: _LinalgBackend -class BatchNormBackend(Enum): ... - class ConvBackend(Enum): ... class Tag(Enum): diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 6b94d9f02611..0386c23d6081 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -1933,114 +1933,6 @@ def _native_batch_norm_legit_functional( return output, save_mean, save_rstd, new_running_mean, new_running_var -def _get_batch_norm_reserve_tensor( - input: Tensor, - weight: Optional[Tensor], - bias: Optional[Tensor], - running_mean: Tensor, - running_var: Tensor, - eps: float, - training: bool, -) -> Tensor: - """ - Return a reserve tensor for batch norm, used only by cudnn to pass forward state to the - backward pass. This is needed for `_batch_norm_with_update` and `_batch_norm_no_update`, - which support a variety of backends including cudnn. We create this tensor here to get - the correct shape in the traced graph if we detect that will call the cudnn kernel, - and rely on DCE to avoid materializing this tensor. 
- """ - backend = torch._C._select_batch_norm_backend( # type: ignore[attr-defined] - input, weight, bias, running_mean, running_var, True, eps - ) - reserve_size = 0 - if backend == torch._C._BatchNormBackend.Cudnn: # type: ignore[attr-defined] - reserve_size = torch._C._get_cudnn_batch_norm_reserve_space_size(input, training) # type: ignore[attr-defined] - return torch.empty( - reserve_size, dtype=torch.uint8, layout=input.layout, device=input.device - ) - - -@register_decomposition(aten._batch_norm_with_update.default) -def _batch_norm_with_update( - input: Tensor, - weight: Optional[Tensor], - bias: Optional[Tensor], - running_mean: Tensor, - running_var: Tensor, - momentum: float, - eps: float, -) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - output, save_mean, save_rstd, _, _ = native_batch_norm_helper( - input, - weight, - bias, - running_mean, - running_var, - True, # training - momentum, - eps, - False, # functional - ) - reserve = _get_batch_norm_reserve_tensor( - input, weight, bias, running_mean, running_var, eps, training=True - ) - return output, save_mean, save_rstd, reserve - - -@register_decomposition(aten._batch_norm_with_update_functional.default) -def _batch_norm_with_update_functional( - input: Tensor, - weight: Optional[Tensor], - bias: Optional[Tensor], - running_mean: Tensor, - running_var: Tensor, - momentum: float, - eps: float, -) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: - ( - output, - save_mean, - save_rstd, - new_rm, - new_rv, - ) = native_batch_norm_helper( - input, weight, bias, running_mean, running_var, True, momentum, eps, True - ) - reserve = _get_batch_norm_reserve_tensor( - input, weight, bias, running_mean, running_var, eps, training=True - ) - assert new_rm is not None, "new_running_mean should not be None" - assert new_rv is not None, "new_running_var should not be None" - return (output, save_mean, save_rstd, reserve, new_rm, new_rv) - - -@register_decomposition(aten._batch_norm_no_update.default) -def _batch_norm_no_update( - input: Tensor, - weight: Optional[Tensor], - bias: Optional[Tensor], - running_mean: Tensor, - running_var: Tensor, - momentum: float, - eps: float, -) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - output, save_mean, save_rstd, _, _ = native_batch_norm_helper( - input, - weight, - bias, - running_mean, - running_var, - False, # training - momentum, - eps, - False, # functional - ) - reserve = _get_batch_norm_reserve_tensor( - input, weight, bias, running_mean, running_var, eps, training=False - ) - return output, save_mean, save_rstd, reserve - - @register_decomposition(aten._fused_dropout) @out_wrapper("out0", "out1") @pw_cast_for_opmath @@ -2145,34 +2037,6 @@ def _broadcast_batch_norm_backward(x, broadcast_mask): return x -@register_decomposition(aten.batch_norm_backward.default) -def batch_norm_backward( - grad_out: Tensor, - input: Tensor, - weight: Optional[Tensor], - running_mean: Optional[Tensor], - running_var: Optional[Tensor], - save_mean: Optional[Tensor], - save_invstd: Optional[Tensor], - train: bool, - eps: float, - output_mask: List[bool], - reserve: Tensor, -) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: - return native_batch_norm_backward( - grad_out, - input, - weight, - running_mean, - running_var, - save_mean, - save_invstd, - train, - eps, - output_mask, - ) - - @register_decomposition(aten.native_batch_norm_backward.default) def native_batch_norm_backward( grad_out: Tensor, diff --git a/torch/_decomp/decompositions_for_jvp.py b/torch/_decomp/decompositions_for_jvp.py index 
diff --git a/torch/_decomp/decompositions_for_jvp.py b/torch/_decomp/decompositions_for_jvp.py
index 81946c314638..19dfaedcce31 100644
--- a/torch/_decomp/decompositions_for_jvp.py
+++ b/torch/_decomp/decompositions_for_jvp.py
@@ -291,34 +291,6 @@ def native_batch_norm_backward(
     return (grad_input, grad_weight, grad_bias)
 
 
-@register_decomposition_for_jvp(aten.batch_norm_backward)
-def batch_norm_backward(
-    grad_out: Tensor,
-    input: Tensor,
-    weight: Tensor,
-    running_mean: Optional[Tensor],
-    running_var: Optional[Tensor],
-    save_mean: Optional[Tensor],
-    save_var: Optional[Tensor],
-    update: bool,
-    eps: float,
-    output_mask: List[bool],
-    reserve: Tensor,
-) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
-    return native_batch_norm_backward(
-        grad_out,
-        input,
-        weight,
-        running_mean,
-        running_var,
-        save_mean,
-        save_var,
-        update,
-        eps,
-        output_mask,
-    )
-
-
 _register_jit_decomposition_for_jvp(torch.ops.aten.trace.default, use_python=True)
 _register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss_backward.default)
 _register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss2d_backward.default)
@@ -328,4 +300,3 @@ _register_jit_decomposition_for_jvp(torch.ops.aten.log_sigmoid_forward.default)
 _register_jit_decomposition_for_jvp(torch.ops.aten.native_layer_norm_backward.default)
 _register_jit_decomposition_for_jvp(torch.ops.aten.native_batch_norm_backward.default)
 _register_jit_decomposition_for_jvp(torch.ops.aten.cudnn_batch_norm_backward.default)
-_register_jit_decomposition_for_jvp(torch.ops.aten.batch_norm_backward.default)
diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py
index d2473c5d6e21..6a74d44a9572 100644
--- a/torch/_dynamo/trace_rules.py
+++ b/torch/_dynamo/trace_rules.py
@@ -1005,7 +1005,6 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
         "torch._C._scatter_out",
         "torch._C._scatter",
         "torch._C._select_conv_backend",
-        "torch._C._select_batch_norm_backend",
         "torch._C._set_autograd_fallback_mode",
         "torch._C._set_backcompat_broadcast_warn",
         "torch._C._set_backcompat_keepdim_warn",
diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py
index 18d0ed1d984a..dd8f86055da1 100644
--- a/torch/_functorch/partitioners.py
+++ b/torch/_functorch/partitioners.py
@@ -753,7 +753,7 @@ def min_cut_rematerialization_partition(
     recomputable_ops = set(recomputable_ops) if recomputable_ops is not None else set(default_recomputable_ops)
 
     random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like]
-    compute_intensive_ops = [aten.mm, aten.convolution, aten.convolution_backward, aten.bmm, aten.addmm, aten.upsample_bilinear2d, aten._softmax, aten._softmax_backward_data, aten.native_layer_norm, aten.native_layer_norm_backward, aten.native_batch_norm, aten.native_batch_norm_backward, aten._native_batch_norm_legit, aten._batch_norm_with_update, aten.batch_norm_backward]  # noqa: E501,B950
+    compute_intensive_ops = [aten.mm, aten.convolution, aten.convolution_backward, aten.bmm, aten.addmm, aten.upsample_bilinear2d, aten._softmax, aten._softmax_backward_data, aten.native_layer_norm, aten.native_layer_norm_backward, aten.native_batch_norm, aten.native_batch_norm_backward, aten._native_batch_norm_legit]  # noqa: E501,B950
     fusible_ops = recomputable_ops | set(random_ops)
 
     if AOT_PARTITIONER_DEBUG:
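For context on the partitioners.py hunk: min_cut_rematerialization_partition recomputes cheap, fusible ops in the backward pass instead of saving their outputs, while ops on the compute_intensive_ops list are always saved. A simplified sketch of that policy, assuming ops are compared as torch.ops.aten packets (the real partitioner operates on an FX graph and weighs many more factors):

import torch

aten = torch.ops.aten

# Abbreviated stand-ins for the lists in the hunk above.
random_ops = {aten.native_dropout, aten.rand_like, aten.randn_like}
compute_intensive_ops = {aten.mm, aten.bmm, aten.addmm, aten.convolution}

def prefer_recompute(op, recomputable_ops):
    # Cheap fusible ops are recomputed in the backward pass to save memory;
    # matmuls and convolutions are saved, since redoing their FLOPs is costlier.
    fusible_ops = recomputable_ops | random_ops
    return op in fusible_ops and op not in compute_intensive_ops

For example, prefer_recompute(aten.add, {aten.add, aten.mul}) returns True, while prefer_recompute(aten.mm, {aten.mm}) returns False.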
diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
index b3213bc8e07a..915c44d360fb 100644
--- a/torch/_inductor/decomposition.py
+++ b/torch/_inductor/decomposition.py
@@ -54,10 +54,6 @@ inductor_decompositions = get_decompositions(
         aten._native_batch_norm_legit,
         aten._native_batch_norm_legit_functional,
         aten._native_batch_norm_legit_no_training,
-        aten._batch_norm_with_update,
-        aten._batch_norm_with_update_functional,
-        aten._batch_norm_no_update,
-        aten.batch_norm_backward,
         aten.native_batch_norm,
         aten.native_group_norm,
         aten.native_layer_norm,
diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp
index ecbe85ced543..05f9ac67fe0f 100644
--- a/torch/csrc/Module.cpp
+++ b/torch/csrc/Module.cpp
@@ -18,7 +18,6 @@
 #include
 #include
 #include
-#include
 #include
 #include
 #include
@@ -92,10 +91,7 @@
 #include
 #include
 #include
-
 #ifdef USE_CUDA
-#include
-#include
 #include
 #endif
 
@@ -2126,44 +2122,6 @@ Call this whenever a new thread is created in order to propagate values from
       },
       "Checks if a tensor's data pointer is COW");
 
-  py_module.def(
-      "_get_cudnn_batch_norm_reserve_space_size",
-      [](const at::Tensor& input, bool training) {
-#ifdef USE_CUDA
-        return at::native::_get_cudnn_batch_norm_reserve_space_size(
-            input, training);
-#else
-        TORCH_CHECK(false, "PyTorch was not built with cuda");
-#endif
-      },
-      py::arg("input"),
-      py::arg("training"));
-
-  py::enum_<at::native::BatchNormBackend>(py_module, "_BatchNormBackend")
-      .value("Native", at::native::BatchNormBackend::Native)
-      .value("Cudnn", at::native::BatchNormBackend::Cudnn)
-      .value("Miopen", at::native::BatchNormBackend::Miopen);
-
-  py_module.def(
-      "_select_batch_norm_backend",
-      [](const at::Tensor& input,
-         const at::Tensor& weight,
-         const at::Tensor& bias,
-         const at::Tensor& running_mean,
-         const at::Tensor& running_var,
-         bool training,
-         double eps) {
-        return at::native::_select_batch_norm_backend(
-            input, weight, bias, running_mean, running_var, training, eps);
-      },
-      py::arg("input"),
-      py::arg("weight"),
-      py::arg("bias"),
-      py::arg("running_mean"),
-      py::arg("running_var"),
-      py::arg("training"),
-      py::arg("eps"));
-
   const auto& defaultGenerator = at::detail::getDefaultCPUGenerator();
   THPDefaultCPUGenerator =
       (THPGenerator*)THPGenerator_initDefaultGenerator(defaultGenerator);
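Before this revert, the Module.cpp bindings deleted above were callable from Python roughly as below. This sketch only runs on builds that predate the revert, and the shapes and stats tensors are made up for illustration:

import torch

# Hypothetical probe of the now-removed backend-selection bindings.
x = torch.randn(8, 32, 16, 16, device="cuda", dtype=torch.half)
stats = torch.zeros(32, device="cuda")  # stand-in for weight/bias/running stats
backend = torch._C._select_batch_norm_backend(x, stats, stats, stats, stats, True, 1e-5)
if backend == torch._C._BatchNormBackend.Cudnn:
    # Byte size of the cudnn workspace that the reserve tensor must hold.
    print(torch._C._get_cudnn_batch_norm_reserve_space_size(x, True))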
diff --git a/torch/csrc/jit/runtime/serialized_shape_function_registry.cpp b/torch/csrc/jit/runtime/serialized_shape_function_registry.cpp
index 231aba3f71b2..943d43f02f73 100644
--- a/torch/csrc/jit/runtime/serialized_shape_function_registry.cpp
+++ b/torch/csrc/jit/runtime/serialized_shape_function_registry.cpp
@@ -3312,7 +3312,6 @@ const OperatorMap& GetShapeFunctionMappings() {
     {"aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "native_batch_norm"},
     {"aten::_native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "native_batch_norm"},
     {"aten::_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "native_batch_norm"},
-    {"aten::_batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)", "native_batch_norm"},
     {"aten::cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor", "cross_entropy_loss"},
     {"aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor", "broadcast_three"},
     {"aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor", "broadcast_one_three"},
diff --git a/torch/jit/_shape_functions.py b/torch/jit/_shape_functions.py
index 9c1da982d81d..51515039866d 100644
--- a/torch/jit/_shape_functions.py
+++ b/torch/jit/_shape_functions.py
@@ -1430,11 +1430,6 @@ add_shape_compute_mapping(
     "aten::_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
     native_batch_norm,
 )
-add_shape_compute_mapping(
-    "_batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)",
-    native_batch_norm,
-)
-
 add_shape_compute_mapping(
     "aten::cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor",
     cross_entropy_loss,
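The shape-compute mappings above pair a TorchScript operator schema with a Python function from input shapes to output shapes. A toy version for the batch-norm family, assuming the usual NCHW layout (hypothetical helper, not the registered native_batch_norm shape function):

from typing import List, Optional, Tuple

def toy_batch_norm_shape(
    input: List[int],
    weight: Optional[List[int]],
    bias: Optional[List[int]],
    running_mean: Optional[List[int]],
    running_var: Optional[List[int]],
    training: bool,
    momentum: float,
    eps: float,
) -> Tuple[List[int], List[int], List[int]]:
    # Output keeps the input shape; saved mean/invstd are per-channel vectors
    # in training mode and empty in eval mode.
    num_features = input[1] if len(input) > 1 else 0
    stat_size = num_features if training else 0
    return list(input), [stat_size], [stat_size]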
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 59cfee43faf3..630f8509918e 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -509,19 +509,6 @@ def sample_inputs__native_batch_norm_legit(op_info, device, dtype, requires_grad
         else:
             yield SampleInput(sample.input, args=(args[2], args[3], training, momentum, eps))
 
-def sample_inputs__batch_norm_with_update(op_info, device, dtype, requires_grad, **kwargs):
-    samples = sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs)
-    for sample in samples:
-        # torch.native_batch_norm does not support 0 numel tensors
-        # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
-        if sample.input.numel() == 0:
-            continue
-        args = sample.args
-        momentum = sample.kwargs.get('momentum', 0.5)
-        eps = sample.kwargs.get('eps', 1e-5)
-        if any(args[i] is None for i in range(4)):
-            continue
-        yield SampleInput(sample.input, args=(args[2], args[3], args[0], args[1], momentum, eps))
 
 def sample_inputs_nn_activation_relu(op_info, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -13111,42 +13098,6 @@ op_db: List[OpInfo] = [
                          "TestCompositeCompliance", "test_forward_ad"),
         )
     ),
-    OpInfo('_batch_norm_with_update',
-           op=torch.ops.aten._batch_norm_with_update,
-           aten_name='_batch_norm_with_update',
-           dtypes=floating_types_and(torch.float16, torch.bfloat16),
-           supports_forward_ad=True,
-           supports_fwgrad_bwgrad=True,
-           assert_jit_shape_analysis=True,
-           # TODO: Avoid COW materialize
-           supports_cow_input_no_materialize=False,
-           sample_inputs_func=sample_inputs__batch_norm_with_update,
-           skips=(
-               # NotImplementedError: Could not run
-               # 'aten::native_batch_norm.out' with arguments from the 'CPU' backend.
-               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type="cpu"),
-               # RuntimeError: out_invstd.dim() == 1 && out_invstd.is_contiguous() && out_invstd.sizes()[0]
-               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type="cuda"),
-               # Problem with _get_numerical_jacobian
-               # IndexError: tuple index out of range
-               DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
-               # RuntimeError: deepEquals(input.iValue, deepCopiedInput) INTERNAL ASSERT FAILED
-               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
-               # https://github.com/pytorch/pytorch/issues/85960
-               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'),
-               DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-5, rtol=5e-5)}),
-                            "TestCompositeCompliance", "test_forward_ad"),
-               # _batch_norm_with_update expects contiguous inputs for cudnn and miopen
-               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type="cuda"),
-               DecorateInfo(unittest.expectedFailure,
-                            'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides', device_type="cuda"),
-               # _batch_norm_with_update does not have python bindings
-               DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
-               # aten out variants do not accept out= kwarg, only python out variants
-               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
-               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
-           )
-    ),
     OpInfo('nn.functional.cosine_similarity',
            aten_name="cosine_similarity",
            dtypes=floating_types_and(torch.half, torch.bfloat16),