Revert "Batch Norm Consolidation (#116092)"

This reverts commit 7b4f70eda519ccd7f28de17689edd43c52743bc9.

Reverted https://github.com/pytorch/pytorch/pull/116092 on behalf of https://github.com/osalpekar due to Causes build failure in //caffe2:aten-hip (AMD build) target. See [D54707318](https://www.internalfb.com/diff/D54707318) for more details, may require internal build system changes to resolve. ([comment](https://github.com/pytorch/pytorch/pull/116092#issuecomment-1989542965))
PyTorch MergeBot
2024-03-11 22:22:39 +00:00
parent 498a94a7f5
commit fd0dbcd891
36 changed files with 72 additions and 772 deletions

View File

@@ -29,11 +29,6 @@
 #include <ATen/ops/_native_batch_norm_legit_native.h>
 #include <ATen/ops/_native_batch_norm_legit_no_training.h>
 #include <ATen/ops/_native_batch_norm_legit_no_training_native.h>
-#include <ATen/ops/_batch_norm_with_update.h>
-#include <ATen/ops/_batch_norm_with_update_native.h>
-#include <ATen/ops/_batch_norm_no_update.h>
-#include <ATen/ops/_batch_norm_no_update_native.h>
-#include <ATen/ops/batch_norm_backward_native.h>
 #include <ATen/ops/alias.h>
 #include <ATen/ops/batch_norm.h>
 #include <ATen/ops/batch_norm_native.h>
@@ -484,58 +479,10 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu_template(
   return std::make_tuple(grad_input, grad_weight, grad_bias);
 }
-BatchNormBackend _select_batch_norm_backend(
-    const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean,
-    const Tensor& running_var, bool training, double eps) {
-  auto& ctx = at::globalContext();
-  bool cudnn_enabled = ctx.userEnabledCuDNN();
-  if (
-      input.is_cuda()
-      && input.scalar_type() != at::kBFloat16 && weight.scalar_type() != at::kBFloat16
-      && (input.scalar_type() != at::kHalf
-        || weight.scalar_type() == at::kFloat)
-      && weight.defined() && bias.defined()
-      && ((running_mean.defined() && running_var.defined())
-        || (!running_mean.defined() && !running_var.defined() && training))
-      && (input.dim() >= 3)
-      && ((input.sym_size(0) <= 880801 && training) // spatial, training
-          ||(input.sym_size(0) <= 65535 && !training)) //spatial, eval
-      && detail::getCUDAHooks().compiledWithCuDNN()
-      && eps >= detail::getCUDAHooks().batchnormMinEpsilonCuDNN()
-      && cudnn_enabled && detail::getCUDAHooks().versionCuDNN() >= 5110L
-      && input.sym_numel() < std::numeric_limits<std::int32_t>::max() // some cuDNN kernels have 32-bit indexing limitations
-  ) {
-    return BatchNormBackend::Cudnn;
-  }
-  if (
-      input.is_cuda()
-      && input.dim() <= MIOPEN_DIM_MAX
-      && input.scalar_type() != at::kDouble
-      && input.scalar_type() != at::kBFloat16
-      && (weight.scalar_type() != at::kHalf)
-      && weight.defined() && bias.defined()
-      && ((running_mean.defined() && running_var.defined())
-        || (!running_mean.defined() && !running_var.defined() && training))
-      && detail::getCUDAHooks().compiledWithMIOpen()
-      && cudnn_enabled
-      && input.suggest_memory_format() != MemoryFormat::ChannelsLast
-      && input.suggest_memory_format() != MemoryFormat::ChannelsLast3d
-  ) {
-    return BatchNormBackend::Miopen;
-  }
-  return BatchNormBackend::Native;
-}
 // _batch_norm_impl_index(_backward) are used in the JIT be able to keep the run-time selection
 // of backends, while enabling it to keep the information about the used backend, so that it can
 // use its corresponding backward implementation.
 // XXX: The indices of backends need to be kept synchronized between this function and its _backward.
-// TODO: remove cudnn_enabled arg
 std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
     const Tensor& input, const c10::optional<Tensor>& weight_opt /* optional */, const c10::optional<Tensor>& bias_opt /* optional */, const c10::optional<Tensor>& running_mean_opt /* optional */, const c10::optional<Tensor>& running_var_opt /* optional */,
     bool training, double momentum, double eps, bool cudnn_enabled) {
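Aside (not part of the diff): the comment above notes that `_batch_norm_impl_index` returns an extra integer identifying the backend that was actually used, so `_batch_norm_impl_index_backward` can route to the matching backward kernel. A minimal, hypothetical Python sketch of inspecting that value through the public aten op binding; the tensor names and shapes are illustrative only:

import torch

x = torch.randn(4, 3, 8, 8)
weight, bias = torch.ones(3), torch.zeros(3)
running_mean, running_var = torch.zeros(3), torch.ones(3)

# The final element of the returned tuple is the backend index that the
# corresponding _backward op expects to receive.
out, save_mean, save_invstd, reserve, impl_index = torch.ops.aten._batch_norm_impl_index(
    x, weight, bias, running_mean, running_var,
    True,   # training
    0.1,    # momentum
    1e-5,   # eps
    torch.backends.cudnn.enabled)
print(impl_index)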
@@ -580,9 +527,24 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
     check_dims_match_num_input_features("bias", std::move(num_features), bias.sym_numel());
   }
-  BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, training, eps);
+  const bool use_cudnn = (
+      input.is_cuda()
+      && input.scalar_type() != at::kBFloat16 && weight.scalar_type() != at::kBFloat16
+      && (input.scalar_type() != at::kHalf
+        || weight.scalar_type() == at::kFloat)
+      && weight.defined() && bias.defined()
+      && ((running_mean.defined() && running_var.defined())
+        || (!running_mean.defined() && !running_var.defined() && training))
+      && (input.dim() >= 3)
+      && ((input.sym_size(0) <= 880801 && training) // spatial, training
+          ||(input.sym_size(0) <= 65535 && !training)) //spatial, eval
+      && detail::getCUDAHooks().compiledWithCuDNN()
+      && eps >= detail::getCUDAHooks().batchnormMinEpsilonCuDNN()
+      && cudnn_enabled && detail::getCUDAHooks().versionCuDNN() >= 5110L
+      && input.sym_numel() < std::numeric_limits<std::int32_t>::max() // some cuDNN kernels have 32-bit indexing limitations
+  );
-  if (backend == BatchNormBackend::Cudnn) {
+  if (use_cudnn) {
     auto input_c = input.contiguous(input.suggest_memory_format());
     auto weight_c = weight.contiguous();
     auto bias_c = bias.contiguous();
@@ -599,7 +561,19 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
   Tensor reserve = at::empty({0}, input.options().dtype(kByte));
-  if (backend == BatchNormBackend::Miopen) {
+  bool use_miopen = (input.is_cuda()
+      && input.dim() <= MIOPEN_DIM_MAX
+      && input.scalar_type() != at::kDouble
+      && input.scalar_type() != at::kBFloat16
+      && (weight.scalar_type() != at::kHalf)
+      && weight.defined() && bias.defined()
+      && ((running_mean.defined() && running_var.defined())
+        || (!running_mean.defined() && !running_var.defined() && training))
+      && detail::getCUDAHooks().compiledWithMIOpen()
+      && cudnn_enabled
+  );
+  if (use_miopen && input.suggest_memory_format() != MemoryFormat::ChannelsLast && input.suggest_memory_format() != MemoryFormat::ChannelsLast3d) {
     return std::tuple_cat(
         at::miopen_batch_norm(
             input.contiguous(), weight.contiguous(), bias.contiguous(),
@@ -663,7 +637,6 @@ std::tuple<Tensor, Tensor, Tensor> _batch_norm_impl_index_backward(
   TORCH_INTERNAL_ASSERT(false, "Unsupported impl_index in _batch_norm_impl_index_backward: ", impl_index);
 }
-// TODO: remove cudnn_enabled arg
 Tensor batch_norm(
     const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
     const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
@@ -674,30 +647,6 @@ Tensor batch_norm(
   const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
   return std::get<0>(at::_batch_norm_impl_index(input, weight, bias, running_mean, running_var,
                                                  training, momentum, eps, cudnn_enabled));
-  // TODO: switch to the new stack after the 2 week FC window
-  // if (training) {
-  //   BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, training, eps);
-  //   if (backend == BatchNormBackend::Cudnn || backend == BatchNormBackend::Miopen) {
-  //     auto input_c = input;
-  //     if (backend == BatchNormBackend::Cudnn) {
-  //       input_c = input.contiguous(input.suggest_memory_format());
-  //     } else {
-  //       input_c = input.contiguous();
-  //     }
-  //     auto weight_c = weight.contiguous();
-  //     auto bias_c = bias.contiguous();
-  //     auto rmean_c = running_mean.defined() ? running_mean.contiguous() : running_mean;
-  //     auto rvar_c = running_var.defined() ? running_var.contiguous() : running_var;
-  //     return std::get<0>(at::_batch_norm_with_update(input_c, weight_c, bias_c, const_cast<Tensor&>(rmean_c),
-  //                                                    const_cast<Tensor&>(rvar_c), momentum, eps));
-  //   } else {
-  //     return std::get<0>(at::_batch_norm_with_update(input, weight, bias, const_cast<Tensor&>(running_mean),
-  //                                                    const_cast<Tensor&>(running_var), momentum, eps));
-  //   }
-  // } else {
-  //   return std::get<0>(at::_batch_norm_no_update(input, weight, bias, running_mean, running_var,
-  //                                                momentum, eps));
-  // }
 }
 Tensor instance_norm(
@@ -849,38 +798,6 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_cpu(const Tensor& self, const c10:
   return batch_norm_cpu_out(self, weight_opt, bias_opt, running_mean_opt, running_var_opt, train, momentum, eps, output, save_mean, save_var);
 }
-std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_cpu(
-    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
-    Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
-  Tensor output, save_mean, save_var;
-  std::tie(output, save_mean, save_var) =
-    batch_norm_cpu(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps);
-  Tensor reserve = at::empty({0}, input.options().dtype(kByte));
-  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
-}
-std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> _batch_norm_with_update_cpu_out(
-    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
-    Tensor& running_mean, Tensor& running_var, double momentum, double eps,
-    Tensor& out, Tensor& save_mean, Tensor& save_var, Tensor& reserve) {
-  std::tie(out, save_mean, save_var) =
-    batch_norm_cpu_out(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps, out, save_mean, save_var);
-  return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>(out, save_mean, save_var, reserve);
-}
-std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_no_update(
-    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
-    const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
-    double momentum, double eps) {
-  const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
-  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
-  Tensor output, save_mean, save_var;
-  std::tie(output, save_mean, save_var) =
-    batch_norm_cpu(input, weight_opt, bias_opt, const_cast<Tensor&>(running_mean), const_cast<Tensor&>(running_var), /*update*/false, momentum, eps);
-  Tensor reserve = at::empty({0}, input.options().dtype(kByte));
-  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
-}
 std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_cpu(
     const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
@@ -909,13 +826,6 @@ std::tuple<Tensor&, Tensor&, Tensor&> _batch_norm_legit_no_stats_cpu_out(const T
   return batch_norm_cpu_out(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, eps, out, save_mean, save_var);
 }
-std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_cpu(
-    const Tensor& grad_output, const Tensor& input, const Tensor& weight,
-    const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
-    const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_var_opt,
-    bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
-  return batch_norm_backward_cpu(grad_output, input, weight, running_mean_opt, running_var_opt, save_mean_opt, save_var_opt, update, eps, grad_input_mask);
-}
 std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu(const Tensor& grad_out, const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt, const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_invstd_opt,
     bool train, double eps, std::array<bool,3> grad_input_mask) {

View File

@@ -8,12 +8,4 @@ namespace at::native {
 using renorm_scale_factor_fn = void (*) (TensorIteratorBase& iter, double maxnorm);
 DECLARE_DISPATCH(renorm_scale_factor_fn, renorm_scale_factor_stub);
-enum class BatchNormBackend {
-  Native,
-  Cudnn,
-  Miopen,
-};
-TORCH_API BatchNormBackend _select_batch_norm_backend(const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean, const Tensor& running_var, bool training, double eps);
 } // namespace at::native

View File

@@ -1,7 +1,5 @@
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include <ATen/cuda/detail/IndexUtils.cuh>
-#include <ATen/detail/CUDAHooksInterface.h>
-#include <ATen/native/Normalization.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/ReduceOps.h>
 #include <ATen/native/Resize.h>
@@ -14,8 +12,6 @@
 #include <ATen/Functions.h>
 #include <ATen/NativeFunctions.h>
 #else
-#include <ATen/ops/_batch_norm_with_update_native.h>
-#include <ATen/ops/batch_norm_backward_native.h>
 #include <ATen/ops/batch_norm_backward_elemt_native.h>
 #include <ATen/ops/batch_norm_backward_reduce_native.h>
 #include <ATen/ops/batch_norm_elemt_native.h>
@@ -23,12 +19,8 @@
 #include <ATen/ops/batch_norm_gather_stats_with_counts_native.h>
 #include <ATen/ops/batch_norm_stats_native.h>
 #include <ATen/ops/batch_norm_update_stats_native.h>
-#include <ATen/ops/cudnn_batch_norm.h>
-#include <ATen/ops/cudnn_batch_norm_backward.h>
 #include <ATen/ops/empty_like.h>
 #include <ATen/ops/from_blob.h>
-#include <ATen/ops/miopen_batch_norm.h>
-#include <ATen/ops/miopen_batch_norm_backward.h>
 #include <ATen/ops/native_batch_norm_backward_native.h>
 #include <ATen/ops/native_batch_norm_native.h>
 #include <ATen/ops/scalar_tensor.h>
@@ -481,54 +473,6 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_cuda(const Tensor& self, const c10
   return std::make_tuple(output, save_mean, save_invstd);
 }
-std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_cuda(
-    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
-    Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
-  // See [Note: hacky wrapper removal for optional tensor]
-  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
-  const Tensor& weight = *weight_maybe_owned;
-  const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
-  Tensor output, save_mean, save_var, reserve;
-  BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, /*training*/true, eps);
-  if (backend == BatchNormBackend::Cudnn) {
-    std::tie(output, save_mean, save_var, reserve) =
-      at::cudnn_batch_norm(input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
-  } else if (backend == BatchNormBackend::Miopen) {
-    reserve = at::empty({0}, input.options().dtype(kByte));
-    std::tie(output, save_mean, save_var) =
-      at::miopen_batch_norm(input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
-  } else {
-    reserve = at::empty({0}, input.options().dtype(kByte));
-    std::tie(output, save_mean, save_var) =
-      batch_norm_cuda(input, weight_opt, bias_opt, running_mean, running_var, /*training*/true, momentum, eps);
-  }
-  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
-}
-std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> _batch_norm_with_update_cuda_out(
-    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
-    Tensor& running_mean, Tensor& running_var, double momentum, double eps,
-    Tensor& out, Tensor& save_mean, Tensor& save_var, Tensor& reserve) {
-  // See [Note: hacky wrapper removal for optional tensor]
-  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
-  const Tensor& weight = *weight_maybe_owned;
-  const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
-  BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, /*training*/true, eps);
-  if (backend == BatchNormBackend::Cudnn) {
-    std::tie(out, save_mean, save_var, reserve) =
-      at::cudnn_batch_norm_out(out, save_mean, save_var, reserve, input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
-  } else if (backend == BatchNormBackend::Miopen) {
-    std::tie(out, save_mean, save_var) =
-      at::miopen_batch_norm_out(out, save_mean, save_var, input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
-  } else {
-    std::tie(out, save_mean, save_var) =
-      batch_norm_cuda_out(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps, out, save_mean, save_var);
-  }
-  return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>(out, save_mean, save_var, reserve);
-}
 std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_cuda(const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt, Tensor& running_mean, Tensor& running_var, bool train, double momentum, double epsilon) {
   return batch_norm_cuda(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, epsilon);
 }
@@ -545,28 +489,6 @@ std::tuple<Tensor&, Tensor&, Tensor&> _batch_norm_legit_no_stats_cuda_out(const
   return batch_norm_cuda_out(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, epsilon, output, save_mean, save_invstd);
 }
-std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_cuda(
-    const Tensor& grad_output, const Tensor& input, const Tensor& weight,
-    const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
-    const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_var_opt,
-    bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
-  const Tensor& dummy_bias = at::empty(1);
-  const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
-  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
-  const Tensor& save_mean = c10::value_or_else(save_mean_opt, [] {return Tensor();});
-  const Tensor& save_var = c10::value_or_else(save_var_opt, [] {return Tensor();});
-  BatchNormBackend backend = _select_batch_norm_backend(input, weight, dummy_bias, running_mean, running_var, /*training*/true, eps);
-  if (backend == BatchNormBackend::Cudnn) {
-    return at::cudnn_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var, eps, reserve);
-  } else if (backend == BatchNormBackend::Miopen) {
-    return at::miopen_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var, eps);
-  } else {
-    return batch_norm_backward_cuda(grad_output, input, weight, running_mean, running_var, save_mean, save_var, update, eps, grad_input_mask);
-  }
-}
 std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda(const Tensor& grad_out, const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt, const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_invstd_opt, bool train, double epsilon, std::array<bool,3> grad_input_mask) {
   // See [Note: hacky wrapper removal for optional tensor]
   c10::MaybeOwned<Tensor> weight = at::borrow_from_optional_tensor(weight_opt);

View File

@@ -2,7 +2,6 @@
 #include <ATen/Config.h>
 #include <ATen/core/Tensor.h>
 #include <ATen/cuda/CUDAConfig.h>
-#include <ATen/native/cudnn/BatchNorm.h>
 #if !AT_CUDNN_ENABLED()
@@ -36,24 +35,18 @@ std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm_backward(
   AT_ERROR("cudnn_batch_norm_backward: ATen not compiled with cuDNN support");
 }
-size_t _get_cudnn_batch_norm_reserve_space_size(
-    const Tensor& input_t,
-    bool training) {
-  AT_ERROR(
-      "_get_cudnn_batch_norm_reserve_space_size: ATen not compiled with cuDNN support");
-}
 } // namespace native
 } // namespace at
 #else // AT_CUDNN_ENABLED
-#include <ATen/TensorUtils.h>
 #include <ATen/cuda/Exceptions.h>
 #include <ATen/cudnn/Descriptors.h>
 #include <ATen/cudnn/Types.h>
 #include <ATen/cudnn/Utils.h>
+#include <ATen/TensorUtils.h>
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
 #include <ATen/NativeFunctions.h>
@@ -98,21 +91,6 @@ cudnnBatchNormMode_t getCudnnBatchNormMode(
 } // namespace
-size_t _get_cudnn_batch_norm_reserve_space_size(
-    const Tensor& input_t,
-    bool training) {
-  size_t reserve_size;
-  TensorArg input{input_t, "input", 1};
-  TensorDescriptor idesc{*input, 4};
-  auto handle = getCudnnHandle();
-  cudnnBatchNormMode_t mode = getCudnnBatchNormMode(
-      training, input->suggest_memory_format(), input->dim());
-  auto op = CUDNN_BATCHNORM_OPS_BN;
-  AT_CUDNN_CHECK(cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
-      handle, mode, op, nullptr, idesc.desc(), &reserve_size));
-  return reserve_size;
-}
 std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
     const Tensor& input_t,
     const Tensor& weight_t,
@@ -201,8 +179,9 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
     Tensor workspace = at::empty(workspace_size, input->options().dtype(kByte));
     // get the reserved size and allocate as tensor
-    size_t reserve_size =
-        _get_cudnn_batch_norm_reserve_space_size(input_t, true /* training */);
+    size_t reserve_size;
+    AT_CUDNN_CHECK(cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
+        handle, mode, op, nullptr, idesc.desc(), &reserve_size));
     reserve = at::empty(reserve_size, input->options().dtype(kByte));
     AT_CUDNN_CHECK(cudnnBatchNormalizationForwardTrainingEx(

View File

@@ -1,6 +0,0 @@
-namespace at::native {
-TORCH_API size_t
-_get_cudnn_batch_norm_reserve_space_size(const Tensor& input_t, bool training);
-} // namespace at::native

View File

@@ -6,8 +6,6 @@
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/NativeFunctions.h>
 #else
-#include <ATen/ops/_batch_norm_with_update_native.h>
-#include <ATen/ops/batch_norm_backward_native.h>
 #include <ATen/ops/_native_batch_norm_legit_native.h>
 #include <ATen/ops/_to_dense_native.h>
 #include <ATen/ops/empty_native.h>
@@ -61,20 +59,6 @@ std::tuple<Tensor, Tensor, Tensor> _mkldnn_batch_norm_legit_no_stats(
   TORCH_CHECK(false, "_mkldnn_batch_norm_legit_no_stats: ATen not compiled with MKLDNN support");
 }
-std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_mkldnn(
-    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
-    Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
-  TORCH_CHECK(false, "_batch_norm_with_update_mkldnn: ATen not compiled with MKLDNN support");
-}
-std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_mkldnn(
-    const Tensor& grad_output, const Tensor& input, const Tensor& weight,
-    const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
-    const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_var_opt,
-    bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
-  TORCH_CHECK(false, "_new_batch_norm_backward_mkldnn: ATen not compiled with MKLDNN support");
-}
 } // namespace native
 } // namespace at
@@ -208,17 +192,6 @@ std::tuple<Tensor, Tensor, Tensor> mkldnn_batch_norm(
 }
-std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_mkldnn(
-    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
-    Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
-  Tensor output, save_mean, save_var;
-  std::tie(output, save_mean, save_var) =
-    mkldnn_batch_norm(input, weight_opt, bias_opt, running_mean, running_var, /*train*/true, momentum, eps);
-  Tensor reserve = empty_mkldnn({0}, input.scalar_type());
-  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
-}
 std::tuple<Tensor, Tensor, Tensor> _mkldnn_batch_norm_legit(
     const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt, Tensor& running_mean, Tensor& running_var,
     bool train,
@@ -237,15 +210,6 @@ std::tuple<Tensor, Tensor, Tensor> _mkldnn_batch_norm_legit_no_stats(
 }
-std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_mkldnn(
-    const Tensor& grad_output, const Tensor& input, const Tensor& weight,
-    const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
-    const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_var_opt,
-    bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
-  return mkldnn_batch_norm_backward(grad_output, input, weight, running_mean_opt, running_var_opt, save_mean_opt, save_var_opt, update, eps, grad_input_mask);
-}
 std::tuple<Tensor, Tensor, Tensor> mkldnn_batch_norm_backward(const Tensor& grad_output,
     const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt, const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_invstd_opt,
     bool train,

View File

@@ -10,9 +10,7 @@
 #include <ATen/Functions.h>
 #include <ATen/NativeFunctions.h>
 #else
-#include <ATen/ops/_batch_norm_with_update_native.h>
 #include <ATen/ops/_native_batch_norm_legit_native.h>
-#include <ATen/ops/batch_norm_backward_native.h>
 #include <ATen/ops/native_batch_norm.h>
 #include <ATen/ops/native_batch_norm_backward_native.h>
 #include <ATen/ops/native_batch_norm_native.h>
@@ -408,36 +406,6 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_mps(const Tensor& self,
   return std::make_tuple(output, save_mean, save_var);
 }
-std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_mps(const Tensor& input,
-    const c10::optional<Tensor>& weight_opt,
-    const c10::optional<Tensor>& bias_opt,
-    Tensor& running_mean,
-    Tensor& running_var,
-    double momentum,
-    double eps) {
-  Tensor output, save_mean, save_var;
-  std::tie(output, save_mean, save_var) =
-      batch_norm_mps(input, weight_opt, bias_opt, running_mean, running_var, /*train*/ true, momentum, eps);
-  Tensor reserve = at::empty({0}, input.options().dtype(kByte));
-  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
-}
-std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> _batch_norm_with_update_mps_out(const Tensor& input,
-    const c10::optional<Tensor>& weight_opt,
-    const c10::optional<Tensor>& bias_opt,
-    Tensor& running_mean,
-    Tensor& running_var,
-    double momentum,
-    double eps,
-    Tensor& out,
-    Tensor& save_mean,
-    Tensor& save_var,
-    Tensor& reserve) {
-  std::tie(out, save_mean, save_var) = batch_norm_mps_out(
-      input, weight_opt, bias_opt, running_mean, running_var, /*update*/ true, momentum, eps, out, save_mean, save_var);
-  return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>(out, save_mean, save_var, reserve);
-}
 std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_mps(const Tensor& self,
     const c10::optional<Tensor>& weight_opt,
     const c10::optional<Tensor>& bias_opt,
@@ -503,29 +471,6 @@ static string get_mem_string(c10::MemoryFormat memory_format) {
 }
 // Batch norm backward
-std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_mps(const Tensor& grad_output,
-    const Tensor& input,
-    const Tensor& weight,
-    const c10::optional<Tensor>& running_mean_opt,
-    const c10::optional<Tensor>& running_var_opt,
-    const c10::optional<Tensor>& save_mean_opt,
-    const c10::optional<Tensor>& save_var_opt,
-    bool update,
-    double eps,
-    std::array<bool, 3> grad_input_mask,
-    const Tensor& reserve) {
-  return batch_norm_backward_mps(grad_output,
-                                 input,
-                                 weight,
-                                 running_mean_opt,
-                                 running_var_opt,
-                                 save_mean_opt,
-                                 save_var_opt,
-                                 update,
-                                 eps,
-                                 grad_input_mask);
-}
 std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_mps(const Tensor& grad_out,
     const Tensor& input,
     const c10::optional<Tensor>& weight_opt,

View File

@@ -6496,32 +6496,6 @@
     SparseCPU, SparseCUDA: norm_sparse
   autogen: native_norm.ScalarOpt_dim_dtype_out
-- func: _batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)
-  dispatch:
-    CPU: _batch_norm_with_update_cpu
-    CUDA: _batch_norm_with_update_cuda
-    MPS: _batch_norm_with_update_mps
-    MkldnnCPU: _batch_norm_with_update_mkldnn
-  autogen: _batch_norm_with_update_functional
-- func: _batch_norm_with_update.out(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps, *, Tensor(d!) out, Tensor(e!) save_mean, Tensor(f!) save_invstd, Tensor(g!) reserve) -> (Tensor(d!), Tensor(e!), Tensor(f!), Tensor(g!))
-  dispatch:
-    CPU: _batch_norm_with_update_cpu_out
-    CUDA: _batch_norm_with_update_cuda_out
-    MPS: _batch_norm_with_update_mps_out
-- func: _batch_norm_no_update(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)
-  dispatch:
-    CompositeExplicitAutograd: _batch_norm_no_update
-  autogen: _batch_norm_no_update.out
-- func: batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, bool update, float eps, bool[3] output_mask, Tensor reserve) -> (Tensor, Tensor, Tensor)
-  dispatch:
-    CPU: _new_batch_norm_backward_cpu
-    CUDA: _new_batch_norm_backward_cuda
-    MPS: _new_batch_norm_backward_mps
-    MkldnnCPU: _new_batch_norm_backward_mkldnn
 # TODO: reduce signatures down to one when optional args is available
 - func: _sparse_sum(Tensor self) -> Tensor
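Aside (not part of the diff): the schemas removed above are the consolidated batch norm ops introduced by #116092. A minimal, hedged sketch of how they are called from Python, assuming a build that still registers these ops (for example one containing the original change); tensor names and shapes are illustrative only:

import torch

x = torch.randn(2, 3, 4, 4)
weight, bias = torch.ones(3), torch.zeros(3)
running_mean, running_var = torch.zeros(3), torch.ones(3)
momentum, eps = 0.1, 1e-5

# Training path: running_mean/running_var are updated in place and a
# (output, save_mean, save_invstd, reserve) tuple is returned, per the schema above.
out, save_mean, save_invstd, reserve = torch.ops.aten._batch_norm_with_update(
    x, weight, bias, running_mean, running_var, momentum, eps)

# Inference path: running statistics are only read, never updated.
out_eval, _, _, _ = torch.ops.aten._batch_norm_no_update(
    x, weight, bias, running_mean, running_var, momentum, eps)

The matching batch_norm_backward schema then consumes the saved statistics together with the reserve tensor, which is how the cuDNN path carries its workspace from forward to backward.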

View File

@@ -114,7 +114,6 @@ dtensor_fails = {
     xfail("as_strided", "partial_views"),
     xfail("as_strided_scatter"),
     xfail("bernoulli"),
-    xfail("_batch_norm_with_update"),
     xfail("block_diag"),
     xfail("broadcast_shapes"),
     xfail("cauchy"),

View File

@@ -6,9 +6,6 @@ aten::_adaptive_avg_pool2d
 aten::_adaptive_avg_pool2d.out
 aten::_addmm_activation
 aten::_addmm_activation.out
-aten::_batch_norm_no_update
-aten::_batch_norm_with_update
-aten::_batch_norm_with_update_functional
 aten::_euclidean_dist.out
 aten::_fused_dropout
 aten::_fused_dropout.out
@@ -79,7 +76,6 @@ aten::atanh
 aten::atanh.out
 aten::atanh_
 aten::baddbmm_
-aten::batch_norm_backward
 aten::bitwise_and.Scalar
 aten::bitwise_and.Scalar_Tensor
 aten::bitwise_and.Scalar_Tensor_out

View File

@@ -30,8 +30,6 @@ aten::_amp_update_scale.out
 aten::_amp_update_scale_
 aten::_assert_async
 aten::_assert_async.msg
-aten::_batch_norm_no_update.out
-aten::_batch_norm_with_update.out
 aten::_cdist_backward
 aten::_cdist_backward.out
 aten::_cdist_forward

View File

@@ -11,7 +11,7 @@ import unittest
 from torch.testing._internal.common_utils import unMarkDynamoStrictTest
 from torch.testing._internal.common_utils import TestCase, run_tests, is_iterable_of_tensors, IS_MACOS, \
-    IS_X86, parametrize, TEST_WITH_ASAN, TEST_WITH_ROCM, noncontiguous_like
+    IS_X86, parametrize, TEST_WITH_ASAN, noncontiguous_like
 from torch.testing._internal.common_utils import skipIfRocm, runOnRocm
 import torch
 from torch import Tensor
@@ -368,10 +368,6 @@ aliasing_ops_list_return = {
     # 'tensor_split' not composite compliant, see vjp_fail
 }
-skip_noncontig = {
-    '_batch_norm_with_update',
-}
 @unittest.skipIf(TEST_WITH_ASAN, "tests time out with asan, are probably redundant")
 @unMarkDynamoStrictTest
@@ -400,14 +396,6 @@ class TestOperators(TestCase):
         xfail('nn.functional.scaled_dot_product_attention'),
         xfail("torch.ops.aten._flash_attention_forward"),
         xfail("torch.ops.aten._efficient_attention_forward"),
-        # RuntimeError: Expected contiguous tensor, but got
-        # non-contiguous tensor for argument #2 'grad_output'
-        decorate(
-            '_batch_norm_with_update',
-            decorator=expectedFailureIf(TEST_WITH_ROCM),
-            device_type='cuda',
-        )
     }))
     @opsToleranceOverride('TestOperators', 'test_grad', (
         tol1('nn.functional.binary_cross_entropy_with_logits',
@@ -445,10 +433,9 @@ class TestOperators(TestCase):
             args = [sample.input] + list(sample.args)
             kwargs = sample.kwargs
-            if op.name not in skip_noncontig:
-                noncontig_sample = sample.noncontiguous()
-                noncontig_args = [noncontig_sample.input] + list(noncontig_sample.args)
-                noncontig_kwargs = noncontig_sample.kwargs
+            noncontig_sample = sample.noncontiguous()
+            noncontig_args = [noncontig_sample.input] + list(noncontig_sample.args)
+            noncontig_kwargs = noncontig_sample.kwargs
             diff_argnums = tuple(i for i, arg in enumerate(args) if diff_arg(arg))
             assert len(diff_argnums) > 0
@@ -471,12 +458,11 @@ class TestOperators(TestCase):
                 return result
             result = grad(wrapped_fn, diff_argnums)(*args, **kwargs)
+            result_noncontig = grad(wrapped_fn, diff_argnums)(*noncontig_args, **noncontig_kwargs)
             expected = _autograd_grad(_as_tuple(wrapped_fn(*args, **kwargs)), diff_args)
-            self.assertEqual(result, expected)
-            if op.name not in skip_noncontig:
-                result_noncontig = grad(wrapped_fn, diff_argnums)(*noncontig_args, **noncontig_kwargs)
-                self.assertEqual(result_noncontig, expected)
+            self.assertEqual(result, expected)
+            self.assertEqual(result_noncontig, expected)
     @with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
     @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
@@ -490,8 +476,7 @@ class TestOperators(TestCase):
         skip('nn.functional.max_unpool2d'), # fails everywhere except on windows
         skip('nn.functional.max_unpool3d'), # fails everywhere except on mac
         xfail("native_batch_norm"), # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents
         xfail("_native_batch_norm_legit"), # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents
-        xfail("_batch_norm_with_update"), # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents
         xfail('nn.functional.scaled_dot_product_attention'),
         xfail('torch.ops.aten._flash_attention_forward'),
@@ -560,17 +545,15 @@ class TestOperators(TestCase):
                 self.jvp_opinfo_test(outplace_variant, sample,
                                      sample.output_process_fn_grad,
                                      clone_inputs=False,
-                                     fixme_ref_jvp_local=fixme_ref_jvp_local,
-                                     test_noncontig=op.name not in skip_noncontig)
+                                     fixme_ref_jvp_local=fixme_ref_jvp_local)
             if is_valid_inplace_sample_input(sample, op, inplace_variant):
                 self.jvp_opinfo_test(inplace_variant, sample,
                                      sample.output_process_fn_grad,
                                      clone_inputs=True,
-                                     fixme_ref_jvp_local=fixme_ref_jvp_local,
-                                     test_noncontig=op.name not in skip_noncontig)
+                                     fixme_ref_jvp_local=fixme_ref_jvp_local)
     def jvp_opinfo_test(self, fn, sample, output_process_fn,
-                        clone_inputs, fixme_ref_jvp_local, test_noncontig):
+                        clone_inputs, fixme_ref_jvp_local):
         # NB: we used requires_grad=True to determine where the primals are,
         # but don't need that information otherwise
         args = (sample.input,) + sample.args
@@ -580,6 +563,15 @@ class TestOperators(TestCase):
         orig_primals = tree_map(lambda x: x.detach(), primals)
         orig_tangents = tree_map(lambda x: torch.randn_like(x), primals)
+        noncontig_sample = sample.noncontiguous()
+        noncontig_args = (noncontig_sample.input,) + noncontig_sample.args
+        noncontig_kwargs = sample.kwargs
+        noncontig_fn, primals = normalize_op_input_output2(
+            fn, noncontig_args, noncontig_kwargs,
+            output_process_fn, requires_grad=True)
+        noncontig_primals = tree_map(lambda x: x.detach(), primals)
+        noncontig_tangents = tree_map(lambda x: noncontiguous_like(x), orig_tangents)
         def maybe_clone_inputs():
             if clone_inputs:
                 primals = tree_map(torch.clone, orig_primals)
@@ -594,24 +586,15 @@ class TestOperators(TestCase):
         primals, tangents = maybe_clone_inputs()
         primal_outs, tangent_outs = jvp(contig_fn, primals, tangents)
+        noncontig_primal_outs, noncontig_tangent_outs = jvp(noncontig_fn,
+                                                            noncontig_primals,
+                                                            noncontig_tangents)
         self.assertEqual(primal_outs, expected_primal_outs)
         self.assertEqual(tangent_outs, expected_tangent_outs)
-        if test_noncontig:
-            noncontig_sample = sample.noncontiguous()
-            noncontig_args = (noncontig_sample.input,) + noncontig_sample.args
-            noncontig_kwargs = sample.kwargs
-            noncontig_fn, primals = normalize_op_input_output2(
-                fn, noncontig_args, noncontig_kwargs,
-                output_process_fn, requires_grad=True)
-            noncontig_primals = tree_map(lambda x: x.detach(), primals)
-            noncontig_tangents = tree_map(lambda x: noncontiguous_like(x), orig_tangents)
-            noncontig_primal_outs, noncontig_tangent_outs = jvp(noncontig_fn,
-                                                                noncontig_primals,
-                                                                noncontig_tangents)
-            self.assertEqual(noncontig_primal_outs, expected_primal_outs)
-            self.assertEqual(noncontig_tangent_outs, expected_tangent_outs)
+        self.assertEqual(noncontig_primal_outs, expected_primal_outs)
+        self.assertEqual(noncontig_tangent_outs, expected_tangent_outs)
     @with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
     @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
@@ -672,22 +655,22 @@ class TestOperators(TestCase):
             result = fn(*primals)
             cotangents = tree_map(lambda x: torch.randn_like(x), result)
+            noncontig_fn, noncontig_primals = normalize_op_input_output(_op, sample.noncontiguous())
+            noncontig_cotangents = tree_map(lambda x: noncontiguous_like(x), cotangents)
             out, vjp_fn = vjp(fn, *primals)
             self.assertEqual(out, result)
             result_vjps = vjp_fn(cotangents)
+            out_noncontig, vjp_fn = vjp(noncontig_fn, *noncontig_primals)
+            self.assertEqual(out_noncontig, result)
+            noncontig_result_vjps = vjp_fn(noncontig_cotangents)
             _, vjp_fn = ref_vjp(fn, *primals)
             expected_vjps = vjp_fn(cotangents)
             self.assertEqual(result_vjps, expected_vjps)
+            self.assertEqual(noncontig_result_vjps, expected_vjps)
-            if op.name not in skip_noncontig:
-                noncontig_fn, noncontig_primals = normalize_op_input_output(_op, sample.noncontiguous())
-                noncontig_cotangents = tree_map(lambda x: noncontiguous_like(x), cotangents)
-                out_noncontig, vjp_fn = vjp(noncontig_fn, *noncontig_primals)
-                self.assertEqual(out_noncontig, result)
-                noncontig_result_vjps = vjp_fn(noncontig_cotangents)
-                self.assertEqual(noncontig_result_vjps, expected_vjps)
         _test(op)
         for a_op in op.aliases:
@@ -847,8 +830,6 @@ class TestOperators(TestCase):
         xfail("to_sparse"),
         xfail("native_batch_norm"),
         xfail("_native_batch_norm_legit"),
-        # TODO: implement batching rule
-        xfail("_batch_norm_with_update"),
     }))
     @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
     @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
@@ -940,8 +921,6 @@ class TestOperators(TestCase):
         skip('linalg.svdvals'), # # really annoying thing where it passes correctness check but not has_batch_rule
         skip("native_batch_norm"),
         skip("_native_batch_norm_legit"),
-        # TODO: implement batching rule
-        skip("_batch_norm_with_update"),
         xfail('__getitem__', ''), # dynamic error
         xfail('nanquantile', device_type='cpu'), # checks q via a .item() call
         xfail('nn.functional.gaussian_nll_loss'), # checks var for if any value < 0
@@ -1066,8 +1045,6 @@ class TestOperators(TestCase):
         xfail('nn.functional.batch_norm', 'without_cudnn'),
         xfail("native_batch_norm"),
         xfail("_native_batch_norm_legit"),
-        # TODO: implement batching rule
-        xfail("_batch_norm_with_update"),
         # https://github.com/pytorch/pytorch/issues/96560
         # ROCm: NotImplementedError
@@ -1253,8 +1230,6 @@ class TestOperators(TestCase):
         xfail('sparse.mm', 'reduce'),
         xfail("native_batch_norm"),
         xfail("_native_batch_norm_legit"),
-        # TODO: implement batching rule
-        xfail("_batch_norm_with_update"),
         xfail("native_dropout_backward"),
         xfail("index_fill"), # aten::_unique hit the vmap fallback which is currently disabled
     }))
@@ -1331,8 +1306,6 @@ class TestOperators(TestCase):
         xfail('sparse.mm', 'reduce'),
         xfail("native_batch_norm"),
         xfail("_native_batch_norm_legit"),
-        # TODO: implement batching rule
-        xfail("_batch_norm_with_update"),
         xfail('as_strided', 'partial_views'),
     }))
     def test_vjpvmap(self, device, dtype, op):
@@ -1591,8 +1564,6 @@ class TestOperators(TestCase):
         # place, were not batched.
         xfail("native_batch_norm"),
         xfail("_native_batch_norm_legit"),
-        # TODO: implement batching rule
-        xfail("_batch_norm_with_update"),
         xfail('native_dropout_backward'),
     }))
     @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
@@ -1837,14 +1808,6 @@ class TestOperators(TestCase):
         skip('sparse.sampled_addmm', ''),
         skip('sparse.mm', 'reduce'),
         skip('native_layer_norm', '', device_type='cpu'),
-        # RuntimeError: Expected contiguous tensor, but got
-        # non-contiguous tensor for argument #2 'grad_output'
-        decorate(
-            '_batch_norm_with_update',
-            decorator=expectedFailureIf(TEST_WITH_ROCM),
-            device_type='cuda',
-        )
     })
     @opsToleranceOverride('TestOperators', 'test_vmap_autograd_grad', (
         tol1('linalg.householder_product',

View File

@@ -3625,8 +3625,6 @@ class TestVmapOperatorsOpInfo(TestCase):
         # which will be updated in place, were not batched.
         xfail('native_batch_norm'),
         xfail('_native_batch_norm_legit'),
-        # TODO: implement batching rule
-        xfail('_batch_norm_with_update'),
         xfail('tril'), # Exception not raised on error input
         xfail('triu'), # Exception not raised on error input
         xfail('as_strided', 'partial_views'),
@@ -3666,8 +3664,6 @@ class TestVmapOperatorsOpInfo(TestCase):
         # which will be updated in place, were not batched.
         xfail('native_batch_norm'),
         xfail('_native_batch_norm_legit'),
-        # TODO: implement batching rule
-        xfail('_batch_norm_with_update'),
         xfail('histogram'),
         xfail('scatter_reduce', 'sum'),
         xfail('scatter_reduce', 'mean'),

View File

@@ -192,7 +192,6 @@ inductor_skips["cuda"] = {
     "nn.functional.cosine_embedding_loss": {b8},
     "native_batch_norm": {f16, f32, f64},
    "_native_batch_norm_legit": {f16, f32, f64},
-    "_batch_norm_with_update": {f16, f32, f64},
 }
 if not SM80OrLater:

View File

@@ -157,11 +157,6 @@ EXPECTED_SKIPS_OR_FAILS: Tuple[onnx_test_common.DecorateMeta, ...] = (
         dtypes=(torch.float16,),
         reason="fixme: Assertion error: result mismatch and type error",
     ),
-    skip(
-        "_batch_norm_with_update",
-        dtypes=(torch.float16,),
-        reason="fixme: Assertion error: result mismatch and type error",
-    ),
     xfail(
         "_softmax_backward_data",
         reason=onnx_test_common.reason_dynamo_does_not_support("assert all(isinstance(a, KNOWN_TYPES) for a in flat_args)")
@@ -1357,20 +1352,6 @@ SKIP_XFAIL_SUBTESTS: tuple[onnx_test_common.DecorateMeta, ...] = (
         model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
         reason="https://github.com/pytorch/pytorch/issues/115106",
     ),
-    skip(
-        "_batch_norm_with_update",
-        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
-        reason="https://github.com/pytorch/pytorch/issues/115106",
-    ),
-    # TODO: This test currently fails only for certain inputs, e.g. shape([3, 1]).
-    # Numerically the ONNX program is correct, but the output shapes for `save_mean`
-    # and `save_var` were tensor(-2.1268) instead of the correct tensor([-2.1268])
-    # for example.
-    skip(
-        "_batch_norm_with_update",
-        model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
-        reason="not supported yet",
-    ),
     xfail(
         "addmm", # xfail can't only use dtypes to catch all cases
         matcher=lambda sample: sample.input.dtype

View File

@@ -22,7 +22,7 @@ torch._C._get_graph_executor_optimize(True)
 from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH_EXECUTOR, \
     enable_profiling_mode_for_profiling_tests, slowTest, skipIfTorchDynamo, TEST_WITH_ASAN, \
-    TEST_WITH_ROCM, IS_FBCODE
+    IS_FBCODE
 from torch.testing._internal.jit_utils import JitTestCase, \
     RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, warmup_backward, set_fusion_group_inlining, \
     clone_inputs, get_traced_sample_variant_pairs, TensorExprTestOptions, NoTracerWarnContextManager
@@ -2202,7 +2202,6 @@ class TestTEFuser(JitTestCase):
     @skipIfTorchDynamo("too slow")
     @unittest.skipIf(TEST_WITH_ASAN, "takes 10+ minutes on asan")
-    @unittest.skipIf(TEST_WITH_ROCM, "Tensor-likes are not close for nans")
     def test_batch_norm(self):
         def test(fn, args):
             trace = torch.jit.trace(fn, args)

View File

@ -708,11 +708,8 @@ meta_function_device_expected_failures_only_outplace = defaultdict(dict)
meta_function_device_skips = defaultdict(dict) meta_function_device_skips = defaultdict(dict)
meta_function_device_expected_failures['cpu'] = { meta_function_device_expected_failures['cpu'] = {
# TODO: The decomps for these batch norm ops return different dtypes depending
# on the device. We should make this work better with meta tensors.
torch.native_batch_norm: {bf16, f16}, torch.native_batch_norm: {bf16, f16},
torch._native_batch_norm_legit: {bf16, f16}, torch._native_batch_norm_legit: {bf16, f16},
torch.ops.aten._batch_norm_with_update: {bf16, f16},
torch.native_layer_norm: {bf16, f16}, torch.native_layer_norm: {bf16, f16},
} }
@ -727,11 +724,8 @@ meta_function_device_expected_failures['cuda'] = {
} }
meta_function_device_skips['cpu'] = { meta_function_device_skips['cpu'] = {
# TODO: The decomps for these batch norm ops return different dtypes depending
# on the device. We should make this work better with meta tensors.
torch.native_batch_norm: {f32, f64}, torch.native_batch_norm: {f32, f64},
torch._native_batch_norm_legit: {f32, f64}, torch._native_batch_norm_legit: {f32, f64},
torch.ops.aten._batch_norm_with_update: {f32, f64},
} }
meta_function_device_skips['cuda'] = { meta_function_device_skips['cuda'] = {
@ -856,13 +850,9 @@ meta_dispatch_device_expected_failures = defaultdict(dict)
meta_dispatch_device_skips = defaultdict(dict) meta_dispatch_device_skips = defaultdict(dict)
meta_dispatch_device_expected_failures['cpu'] = { meta_dispatch_device_expected_failures['cpu'] = {
# TODO: The decomps for these batch norm ops return different dtypes depending
# on the device. We should make this work better with meta tensors.
aten.native_batch_norm.default: {bf16, f16}, aten.native_batch_norm.default: {bf16, f16},
aten._native_batch_norm_legit.default: {bf16, f16}, aten._native_batch_norm_legit.default: {bf16, f16},
aten._native_batch_norm_legit.no_stats: {bf16, f16}, aten._native_batch_norm_legit.no_stats: {bf16, f16},
aten._batch_norm_with_update.default: {bf16, f16},
aten.native_layer_norm.default: {bf16, f16}, aten.native_layer_norm.default: {bf16, f16},
aten.histc.default: {f16}, aten.histc.default: {f16},
aten.histc.out: {f16}, aten.histc.out: {f16},
@ -887,13 +877,9 @@ meta_dispatch_device_expected_failures['cuda'] = {
meta_dispatch_device_skips['cpu'] = { meta_dispatch_device_skips['cpu'] = {
aten._embedding_bag_forward_only.default: {bf16, f16, f32, f64}, aten._embedding_bag_forward_only.default: {bf16, f16, f32, f64},
# TODO: The decomps for these batch norm ops return different dtypes depending
# on the device. We should make this work better with meta tensors.
aten.native_batch_norm.default: {f32, f64}, aten.native_batch_norm.default: {f32, f64},
aten._native_batch_norm_legit.default: {f32, f64}, aten._native_batch_norm_legit.default: {f32, f64},
aten._native_batch_norm_legit.no_stats: {f32, f64}, aten._native_batch_norm_legit.no_stats: {f32, f64},
aten._batch_norm_with_update.default: {f32, f64},
# If the computation dtype is different from the input # If the computation dtype is different from the input
# dtype this will fail. CPU execution may also have a # dtype this will fail. CPU execution may also have a

View File

@ -11395,7 +11395,6 @@ class TestConsistency(TestCaseMPS):
'nn.functional.gelu', 'nn.functional.gelu',
'nn.functional.glu', 'nn.functional.glu',
'_native_batch_norm_legit', '_native_batch_norm_legit',
'_batch_norm_with_update',
'native_batch_norm', 'native_batch_norm',
'softmax', 'softmax',
'_softmax_backward_data', '_softmax_backward_data',

View File

@ -1927,15 +1927,6 @@ inplace_symbolic_tensor_failures = {
} }
out_symbolic_tensor_failures = { out_symbolic_tensor_failures = {
# Cast error details: Unable to cast (...) to Tensor
#
# This happens because the test is set up to call the out variant using the `out` kwarg:
# torch._some_op(arg1, arg2, out=(out1, out2, out3))
#
# However, this only works for torch ops, not aten ops. `_batch_norm_with_update`
# has no python bindings, so its out variant cannot be called through the `out`
# kwarg (a short illustration of the distinction follows below).
xfail('_batch_norm_with_update', ''),
xfail('_native_batch_norm_legit', ''), xfail('_native_batch_norm_legit', ''),
xfail('angle', ''), xfail('angle', ''),
xfail('argmax', ''), xfail('argmax', ''),
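For context on the removed comment above: the tuple-style `out=` kwarg is a convention of the Python-bound torch API, while aten overloads expose each output tensor as its own schema argument. A minimal sketch of the distinction, using torch.sort purely as a stand-in (it is unrelated to this revert):

    import torch

    x = torch.randn(8)
    values = torch.empty(8)
    indices = torch.empty(8, dtype=torch.long)

    # Python-bound ops with multiple outputs accept a tuple through `out=`.
    torch.sort(x, out=(values, indices))

    # The matching aten overload names each output tensor explicitly instead,
    # so the tuple-style `out=` used by the test harness does not apply to it.
    torch.ops.aten.sort.values(x, values=values, indices=indices)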

View File

@ -1250,20 +1250,6 @@
self: grad.neg() self: grad.neg()
result: auto_element_wise result: auto_element_wise
- name: _batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)
input, weight, bias: "grad.defined() ? batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, /*update*/true, eps, grad_input_mask, retain_variables ? result3.clone() : result3) : std::tuple<Tensor, Tensor, Tensor>()"
result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, true, eps)
- name: _batch_norm_no_update(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)
input, weight, bias: "grad.defined() ? batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, /*update*/false, eps, grad_input_mask, retain_variables ? result3.clone() : result3) : std::tuple<Tensor, Tensor, Tensor>()"
result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, false, eps)
- name: batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, bool update, float eps, bool[3] output_mask, Tensor reserve) -> (Tensor, Tensor, Tensor)
input, weight, grad_out: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, running_mean, running_var, update, eps, save_mean, save_var, grad_input_mask)
save_mean: not_implemented("batch_norm_backward save_mean")
save_var: not_implemented("batch_norm_backward save_var")
reserve: not_implemented("batch_norm_backward reserve")
- name: nextafter(Tensor self, Tensor other) -> Tensor - name: nextafter(Tensor self, Tensor other) -> Tensor
self: not_implemented("nextafter") self: not_implemented("nextafter")
other: not_implemented("nextafter") other: not_implemented("nextafter")

View File

@ -160,12 +160,9 @@ _SKIP_PYTHON_BINDINGS = [
"fill.Tensor", # only used by the functionalization pass "fill.Tensor", # only used by the functionalization pass
"fill.Scalar", # only used by the functionalization pass "fill.Scalar", # only used by the functionalization pass
"lift.*", "lift.*",
"normal_functional", # only used by the functionalization pass "normal_functional", # only used by the functionalization pas
"nbytes", "nbytes",
"itemsize", "itemsize",
"_batch_norm_with_update",
"_batch_norm_with_update_out",
"_batch_norm_no_update",
] ]
SKIP_PYTHON_BINDINGS = [ SKIP_PYTHON_BINDINGS = [

View File

@ -1132,7 +1132,6 @@ def _meta_in_tls_dispatch_include() -> _bool: ...
def _stash_obj_in_tls(key: str, arg: Any) -> None: ... def _stash_obj_in_tls(key: str, arg: Any) -> None: ...
def _get_obj_in_tls(key: str) -> Any: ... def _get_obj_in_tls(key: str) -> Any: ...
def _is_key_in_tls(key: str) -> _bool: ... def _is_key_in_tls(key: str) -> _bool: ...
def _select_batch_norm_backend(*args, **kwargs) -> BatchNormBackend: ...
def _select_conv_backend(*args, **kwargs) -> ConvBackend: ... def _select_conv_backend(*args, **kwargs) -> ConvBackend: ...
def _conv_determine_backend_memory_format( def _conv_determine_backend_memory_format(
input: Tensor, input: Tensor,
@ -1198,8 +1197,6 @@ class _LinalgBackend:
Cusolver: _LinalgBackend Cusolver: _LinalgBackend
Magma: _LinalgBackend Magma: _LinalgBackend
class BatchNormBackend(Enum): ...
class ConvBackend(Enum): ... class ConvBackend(Enum): ...
class Tag(Enum): class Tag(Enum):

View File

@ -1933,114 +1933,6 @@ def _native_batch_norm_legit_functional(
return output, save_mean, save_rstd, new_running_mean, new_running_var return output, save_mean, save_rstd, new_running_mean, new_running_var
def _get_batch_norm_reserve_tensor(
input: Tensor,
weight: Optional[Tensor],
bias: Optional[Tensor],
running_mean: Tensor,
running_var: Tensor,
eps: float,
training: bool,
) -> Tensor:
"""
Return a reserve tensor for batch norm, used only by cudnn to pass forward state to the
backward pass. This is needed for `_batch_norm_with_update` and `_batch_norm_no_update`,
which support a variety of backends including cudnn. We create this tensor here to get
the correct shape in the traced graph if we detect that we will call the cudnn kernel,
and rely on DCE to avoid materializing this tensor (see the sketch after this function).
"""
backend = torch._C._select_batch_norm_backend( # type: ignore[attr-defined]
input, weight, bias, running_mean, running_var, True, eps
)
reserve_size = 0
if backend == torch._C._BatchNormBackend.Cudnn: # type: ignore[attr-defined]
reserve_size = torch._C._get_cudnn_batch_norm_reserve_space_size(input, training) # type: ignore[attr-defined]
return torch.empty(
reserve_size, dtype=torch.uint8, layout=input.layout, device=input.device
)
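A minimal sketch of what the helper above produces at runtime, relying on the pre-revert `aten._batch_norm_with_update` op removed by this commit, so it is illustrative only: on a CPU input the backend check falls through to the native path, the reserve output stays empty, and DCE can drop it from traced graphs.

    import torch

    x = torch.randn(2, 3, 4)                       # CPU input -> native backend
    weight, bias = torch.ones(3), torch.zeros(3)
    running_mean, running_var = torch.zeros(3), torch.ones(3)

    # Pre-revert op; removed by this commit.
    out, save_mean, save_rstd, reserve = torch.ops.aten._batch_norm_with_update(
        x, weight, bias, running_mean, running_var, 0.1, 1e-5
    )
    assert reserve.numel() == 0  # empty uint8 placeholder on non-cudnn backends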
@register_decomposition(aten._batch_norm_with_update.default)
def _batch_norm_with_update(
input: Tensor,
weight: Optional[Tensor],
bias: Optional[Tensor],
running_mean: Tensor,
running_var: Tensor,
momentum: float,
eps: float,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
input,
weight,
bias,
running_mean,
running_var,
True, # training
momentum,
eps,
False, # functional
)
reserve = _get_batch_norm_reserve_tensor(
input, weight, bias, running_mean, running_var, eps, training=True
)
return output, save_mean, save_rstd, reserve
@register_decomposition(aten._batch_norm_with_update_functional.default)
def _batch_norm_with_update_functional(
input: Tensor,
weight: Optional[Tensor],
bias: Optional[Tensor],
running_mean: Tensor,
running_var: Tensor,
momentum: float,
eps: float,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
(
output,
save_mean,
save_rstd,
new_rm,
new_rv,
) = native_batch_norm_helper(
input, weight, bias, running_mean, running_var, True, momentum, eps, True
)
reserve = _get_batch_norm_reserve_tensor(
input, weight, bias, running_mean, running_var, eps, training=True
)
assert new_rm is not None, "new_running_mean should not be None"
assert new_rv is not None, "new_running_var should not be None"
return (output, save_mean, save_rstd, reserve, new_rm, new_rv)
@register_decomposition(aten._batch_norm_no_update.default)
def _batch_norm_no_update(
input: Tensor,
weight: Optional[Tensor],
bias: Optional[Tensor],
running_mean: Tensor,
running_var: Tensor,
momentum: float,
eps: float,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
input,
weight,
bias,
running_mean,
running_var,
False, # training
momentum,
eps,
False, # functional
)
reserve = _get_batch_norm_reserve_tensor(
input, weight, bias, running_mean, running_var, eps, training=False
)
return output, save_mean, save_rstd, reserve
@register_decomposition(aten._fused_dropout) @register_decomposition(aten._fused_dropout)
@out_wrapper("out0", "out1") @out_wrapper("out0", "out1")
@pw_cast_for_opmath @pw_cast_for_opmath
@ -2145,34 +2037,6 @@ def _broadcast_batch_norm_backward(x, broadcast_mask):
return x return x
@register_decomposition(aten.batch_norm_backward.default)
def batch_norm_backward(
grad_out: Tensor,
input: Tensor,
weight: Optional[Tensor],
running_mean: Optional[Tensor],
running_var: Optional[Tensor],
save_mean: Optional[Tensor],
save_invstd: Optional[Tensor],
train: bool,
eps: float,
output_mask: List[bool],
reserve: Tensor,
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
return native_batch_norm_backward(
grad_out,
input,
weight,
running_mean,
running_var,
save_mean,
save_invstd,
train,
eps,
output_mask,
)
@register_decomposition(aten.native_batch_norm_backward.default) @register_decomposition(aten.native_batch_norm_backward.default)
def native_batch_norm_backward( def native_batch_norm_backward(
grad_out: Tensor, grad_out: Tensor,

View File

@ -291,34 +291,6 @@ def native_batch_norm_backward(
return (grad_input, grad_weight, grad_bias) return (grad_input, grad_weight, grad_bias)
@register_decomposition_for_jvp(aten.batch_norm_backward)
def batch_norm_backward(
grad_out: Tensor,
input: Tensor,
weight: Tensor,
running_mean: Optional[Tensor],
running_var: Optional[Tensor],
save_mean: Optional[Tensor],
save_var: Optional[Tensor],
update: bool,
eps: float,
output_mask: List[bool],
reserve: Tensor,
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
return native_batch_norm_backward(
grad_out,
input,
weight,
running_mean,
running_var,
save_mean,
save_var,
update,
eps,
output_mask,
)
_register_jit_decomposition_for_jvp(torch.ops.aten.trace.default, use_python=True) _register_jit_decomposition_for_jvp(torch.ops.aten.trace.default, use_python=True)
_register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss_backward.default) _register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss2d_backward.default) _register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss2d_backward.default)
@ -328,4 +300,3 @@ _register_jit_decomposition_for_jvp(torch.ops.aten.log_sigmoid_forward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.native_layer_norm_backward.default) _register_jit_decomposition_for_jvp(torch.ops.aten.native_layer_norm_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.native_batch_norm_backward.default) _register_jit_decomposition_for_jvp(torch.ops.aten.native_batch_norm_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.cudnn_batch_norm_backward.default) _register_jit_decomposition_for_jvp(torch.ops.aten.cudnn_batch_norm_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.batch_norm_backward.default)

View File

@ -1005,7 +1005,6 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
"torch._C._scatter_out", "torch._C._scatter_out",
"torch._C._scatter", "torch._C._scatter",
"torch._C._select_conv_backend", "torch._C._select_conv_backend",
"torch._C._select_batch_norm_backend",
"torch._C._set_autograd_fallback_mode", "torch._C._set_autograd_fallback_mode",
"torch._C._set_backcompat_broadcast_warn", "torch._C._set_backcompat_broadcast_warn",
"torch._C._set_backcompat_keepdim_warn", "torch._C._set_backcompat_keepdim_warn",

View File

@ -753,7 +753,7 @@ def min_cut_rematerialization_partition(
recomputable_ops = set(recomputable_ops) if recomputable_ops is not None else set(default_recomputable_ops) recomputable_ops = set(recomputable_ops) if recomputable_ops is not None else set(default_recomputable_ops)
random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like] random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like]
compute_intensive_ops = [aten.mm, aten.convolution, aten.convolution_backward, aten.bmm, aten.addmm, aten.upsample_bilinear2d, aten._softmax, aten._softmax_backward_data, aten.native_layer_norm, aten.native_layer_norm_backward, aten.native_batch_norm, aten.native_batch_norm_backward, aten._native_batch_norm_legit, aten._batch_norm_with_update, aten.batch_norm_backward] # noqa: E501,B950 compute_intensive_ops = [aten.mm, aten.convolution, aten.convolution_backward, aten.bmm, aten.addmm, aten.upsample_bilinear2d, aten._softmax, aten._softmax_backward_data, aten.native_layer_norm, aten.native_layer_norm_backward, aten.native_batch_norm, aten.native_batch_norm_backward, aten._native_batch_norm_legit] # noqa: E501,B950
fusible_ops = recomputable_ops | set(random_ops) fusible_ops = recomputable_ops | set(random_ops)
if AOT_PARTITIONER_DEBUG: if AOT_PARTITIONER_DEBUG:

View File

@ -54,10 +54,6 @@ inductor_decompositions = get_decompositions(
aten._native_batch_norm_legit, aten._native_batch_norm_legit,
aten._native_batch_norm_legit_functional, aten._native_batch_norm_legit_functional,
aten._native_batch_norm_legit_no_training, aten._native_batch_norm_legit_no_training,
aten._batch_norm_with_update,
aten._batch_norm_with_update_functional,
aten._batch_norm_no_update,
aten.batch_norm_backward,
aten.native_batch_norm, aten.native_batch_norm,
aten.native_group_norm, aten.native_group_norm,
aten.native_layer_norm, aten.native_layer_norm,

View File

@ -18,7 +18,6 @@
#include <ATen/dlpack.h> #include <ATen/dlpack.h>
#include <ATen/native/ConvUtils.h> #include <ATen/native/ConvUtils.h>
#include <ATen/native/ForeachUtils.h> #include <ATen/native/ForeachUtils.h>
#include <ATen/native/Normalization.h>
#include <c10/core/DispatchKeySet.h> #include <c10/core/DispatchKeySet.h>
#include <c10/util/AbortHandler.h> #include <c10/util/AbortHandler.h>
#include <c10/util/Backtrace.h> #include <c10/util/Backtrace.h>
@ -92,10 +91,7 @@
#include <ATen/native/transformers/sdp_utils_cpp.h> #include <ATen/native/transformers/sdp_utils_cpp.h>
#include <torch/csrc/profiler/combined_traceback.h> #include <torch/csrc/profiler/combined_traceback.h>
#include <sstream> #include <sstream>
#ifdef USE_CUDA #ifdef USE_CUDA
#include <ATen/cuda/CUDAConfig.h>
#include <ATen/native/cudnn/BatchNorm.h>
#include <ATen/native/transformers/cuda/sdp_utils.h> #include <ATen/native/transformers/cuda/sdp_utils.h>
#endif #endif
@ -2126,44 +2122,6 @@ Call this whenever a new thread is created in order to propagate values from
}, },
"Checks if a tensor's data pointer is COW"); "Checks if a tensor's data pointer is COW");
py_module.def(
"_get_cudnn_batch_norm_reserve_space_size",
[](const at::Tensor& input, bool training) {
#ifdef USE_CUDA
return at::native::_get_cudnn_batch_norm_reserve_space_size(
input, training);
#else
TORCH_CHECK(false, "PyTorch was not built with cuda");
#endif
},
py::arg("input"),
py::arg("training"));
py::enum_<at::native::BatchNormBackend>(py_module, "_BatchNormBackend")
.value("Native", at::native::BatchNormBackend::Native)
.value("Cudnn", at::native::BatchNormBackend::Cudnn)
.value("Miopen", at::native::BatchNormBackend::Miopen);
py_module.def(
"_select_batch_norm_backend",
[](const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_var,
bool training,
double eps) {
return at::native::_select_batch_norm_backend(
input, weight, bias, running_mean, running_var, training, eps);
},
py::arg("input"),
py::arg("weight"),
py::arg("bias"),
py::arg("running_mean"),
py::arg("running_var"),
py::arg("training"),
py::arg("eps"));
const auto& defaultGenerator = at::detail::getDefaultCPUGenerator(); const auto& defaultGenerator = at::detail::getDefaultCPUGenerator();
THPDefaultCPUGenerator = THPDefaultCPUGenerator =
(THPGenerator*)THPGenerator_initDefaultGenerator(defaultGenerator); (THPGenerator*)THPGenerator_initDefaultGenerator(defaultGenerator);

View File

@ -3312,7 +3312,6 @@ const OperatorMap<std::string>& GetShapeFunctionMappings() {
{"aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "native_batch_norm"}, {"aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "native_batch_norm"},
{"aten::_native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "native_batch_norm"}, {"aten::_native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "native_batch_norm"},
{"aten::_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "native_batch_norm"}, {"aten::_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "native_batch_norm"},
{"aten::_batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)", "native_batch_norm"},
{"aten::cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor", "cross_entropy_loss"}, {"aten::cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor", "cross_entropy_loss"},
{"aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor", "broadcast_three"}, {"aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor", "broadcast_three"},
{"aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor", "broadcast_one_three"}, {"aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor", "broadcast_one_three"},

View File

@ -1430,11 +1430,6 @@ add_shape_compute_mapping(
"aten::_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "aten::_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
native_batch_norm, native_batch_norm,
) )
add_shape_compute_mapping(
"_batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)",
native_batch_norm,
)
add_shape_compute_mapping( add_shape_compute_mapping(
"aten::cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor", "aten::cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor",
cross_entropy_loss, cross_entropy_loss,

View File

@ -509,19 +509,6 @@ def sample_inputs__native_batch_norm_legit(op_info, device, dtype, requires_grad
else: else:
yield SampleInput(sample.input, args=(args[2], args[3], training, momentum, eps)) yield SampleInput(sample.input, args=(args[2], args[3], training, momentum, eps))
def sample_inputs__batch_norm_with_update(op_info, device, dtype, requires_grad, **kwargs):
samples = sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs)
for sample in samples:
# torch.native_batch_norm does not support 0 numel tensors
# IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
if sample.input.numel() == 0:
continue
args = sample.args
momentum = sample.kwargs.get('momentum', 0.5)
eps = sample.kwargs.get('eps', 1e-5)
if any(args[i] is None for i in range(4)):
continue
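# args from sample_inputs_batch_norm are (running_mean, running_var, weight, bias);
# reorder them to match the _batch_norm_with_update schema:
# (input, weight, bias, running_mean, running_var, momentum, eps).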
yield SampleInput(sample.input, args=(args[2], args[3], args[0], args[1], momentum, eps))
def sample_inputs_nn_activation_relu(op_info, device, dtype, requires_grad, **kwargs): def sample_inputs_nn_activation_relu(op_info, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@ -13111,42 +13098,6 @@ op_db: List[OpInfo] = [
"TestCompositeCompliance", "test_forward_ad"), "TestCompositeCompliance", "test_forward_ad"),
) )
), ),
OpInfo('_batch_norm_with_update',
op=torch.ops.aten._batch_norm_with_update,
aten_name='_batch_norm_with_update',
dtypes=floating_types_and(torch.float16, torch.bfloat16),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
assert_jit_shape_analysis=True,
# TODO: Avoid COW materialize
supports_cow_input_no_materialize=False,
sample_inputs_func=sample_inputs__batch_norm_with_update,
skips=(
# NotImplementedError: Could not run
# 'aten::native_batch_norm.out' with arguments from the 'CPU' backend.
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type="cpu"),
# RuntimeError: out_invstd.dim() == 1 && out_invstd.is_contiguous() && out_invstd.sizes()[0]
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type="cuda"),
# Problem with _get_numerical_jacobian
# IndexError: tuple index out of range
DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
# RuntimeError: deepEquals(input.iValue, deepCopiedInput) INTERNAL ASSERT FAILED
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
# https://github.com/pytorch/pytorch/issues/85960
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'),
DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-5, rtol=5e-5)}),
"TestCompositeCompliance", "test_forward_ad"),
# _batch_norm_with_update expects contiguous inputs for cudnn and miopen
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type="cuda"),
DecorateInfo(unittest.expectedFailure,
'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides', device_type="cuda"),
# _batch_norm_with_update does not have python bindings
DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
# aten out variants do not accept the out= kwarg; only python-bound out variants do
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
)
),
OpInfo('nn.functional.cosine_similarity', OpInfo('nn.functional.cosine_similarity',
aten_name="cosine_similarity", aten_name="cosine_similarity",
dtypes=floating_types_and(torch.half, torch.bfloat16), dtypes=floating_types_and(torch.half, torch.bfloat16),