Compare commits

...

1 Commit

Author SHA1 Message Date
24e35f0c37 Switch batch norm stack to consolidated ops
Summary: This commit switches `aten.batch_norm` to call the
new `batch_norm_with_update` and `batch_norm_no_update` ops,
instead of the old `_batch_norm_impl_index` op. The new stack
is "consolidated" in the sense that there is a single backend
agnostic op that will internally pick the right kernel based
on the backend, but this detail will be hidden away from the
user and from the model graph.

ghstack-source-id: 518baff49c66aeccd4d28a50522fe104bc323d1b
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119496
2024-07-24 15:03:37 -07:00
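
For orientation, here is a minimal sketch of what the consolidated entry points look like from Python (hypothetical shapes; assumes a build in which the new ops are registered):

import torch

x = torch.randn(2, 3, 4, 4)
weight, bias = torch.randn(3), torch.randn(3)
rm, rv = torch.zeros(3), torch.ones(3)

# Training with defined running stats: a single backend-agnostic op that
# internally selects the cuDNN/MIOpen/native kernel and updates rm/rv in place.
out, save_mean, save_var, reserve = torch.ops.aten._batch_norm_with_update(
    x, weight, bias, rm, rv, 0.1, 1e-5)

# Eval, or training with running stats disabled (passed as None): no update.
out, save_mean, save_var, reserve = torch.ops.aten._batch_norm_no_update(
    x, weight, bias, rm, rv, 0.1, 1e-5)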
16 changed files with 322 additions and 93 deletions

View File

@@ -861,10 +861,21 @@ static std::tuple<at::Tensor,at::Tensor,at::Tensor> _native_batch_norm_legit_no_
return at::native_batch_norm(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, eps);
}
+static std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor> _batch_norm_with_update_batch(
+const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
+Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
+at::Tensor output, save_mean, save_var;
+std::tie(output, save_mean, save_var) =
+at::native_batch_norm(self, weight_opt, bias_opt, running_mean, running_var, true/*train*/, momentum, eps);
+at::Tensor reserve = at::empty({0}, self.options().dtype(kByte));
+return std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>(output, save_mean, save_var, reserve);
+}
TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
VMAP_SUPPORT(native_batch_norm, NATIVE_BATCH_NORM_BATCH_RULE(native_batch_norm));
VMAP_SUPPORT(cudnn_batch_norm, CUDNN_BATCH_NORM_BATCH_RULE(cudnn_batch_norm));
VMAP_SUPPORT(miopen_batch_norm, MIOPEN_BATCH_NORM_BATCH_RULE(miopen_batch_norm));
+m.impl("_batch_norm_with_update", _batch_norm_with_update_batch);
m.impl("_native_batch_norm_legit", _native_batch_norm_legit_batch);
m.impl("_native_batch_norm_legit.no_stats", _native_batch_norm_legit_no_stats_batch);
m.impl("native_batch_norm_backward", NATIVE_BATCH_NORM_BACKWARD_BATCH_RULE(native_batch_norm_backward));

View File

@@ -485,20 +485,31 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu_template(
}
BatchNormBackend _select_batch_norm_backend(
-const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean,
-const Tensor& running_var, bool training, double eps) {
+const Tensor& input,
+const Tensor& weight,
+const Tensor& bias,
+const std::optional<Tensor>& running_mean_opt,
+const std::optional<Tensor>& running_var_opt,
+bool training,
+double eps) {
auto& ctx = at::globalContext();
bool cudnn_enabled = ctx.userEnabledCuDNN();
+const bool has_running_mean = running_mean_opt.has_value() && running_mean_opt->defined();
+const bool has_running_var = running_var_opt.has_value() && running_var_opt->defined();
+const bool has_valid_running_stats = (
+(has_running_mean && has_running_var) ||
+(!has_running_mean && !has_running_var && training)
+);
if (
input.is_cuda()
&& input.scalar_type() != at::kBFloat16 && weight.scalar_type() != at::kBFloat16
&& (input.scalar_type() != at::kHalf
|| weight.scalar_type() == at::kFloat)
&& weight.defined() && bias.defined()
-&& ((running_mean.defined() && running_var.defined())
-|| (!running_mean.defined() && !running_var.defined() && training))
+&& has_valid_running_stats
&& (input.dim() >= 3)
&& ((input.sym_size(0) <= 880801 && training) // spatial, training
|| (input.sym_size(0) <= 65535 && !training)) // spatial, eval
@@ -517,8 +528,7 @@ BatchNormBackend _select_batch_norm_backend(
&& input.scalar_type() != at::kBFloat16
&& (weight.scalar_type() != at::kHalf)
&& weight.defined() && bias.defined()
-&& ((running_mean.defined() && running_var.defined())
-|| (!running_mean.defined() && !running_var.defined() && training))
+&& has_valid_running_stats
&& (input.dim() >= 3)
&& detail::getCUDAHooks().compiledWithMIOpen()
&& cudnn_enabled
@@ -669,36 +679,49 @@ Tensor batch_norm(
const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
bool training, double momentum, double eps, bool cudnn_enabled) {
-const Tensor& weight = c10::value_or_else(weight_opt, [] {return Tensor();});
+// See [Note: hacky wrapper removal for optional tensor]
+c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
+const Tensor& weight = *weight_maybe_owned;
const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
-return std::get<0>(at::_batch_norm_impl_index(input, weight, bias, running_mean, running_var,
-training, momentum, eps, cudnn_enabled));
-// TODO: switch to the new stack after the 2 week FC window
-// if (training) {
-// BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, training, eps);
-// if (backend == BatchNormBackend::Cudnn || backend == BatchNormBackend::Miopen) {
-// auto input_c = input;
-// if (backend == BatchNormBackend::Cudnn) {
-// input_c = input.contiguous(input.suggest_memory_format());
-// } else {
-// input_c = input.contiguous();
-// }
-// auto weight_c = weight.contiguous();
-// auto bias_c = bias.contiguous();
-// auto rmean_c = running_mean.defined() ? running_mean.contiguous() : running_mean;
-// auto rvar_c = running_var.defined() ? running_var.contiguous() : running_var;
-// return std::get<0>(at::_batch_norm_with_update(input_c, weight_c, bias_c, const_cast<Tensor&>(rmean_c),
-// const_cast<Tensor&>(rvar_c), momentum, eps));
-// } else {
-// return std::get<0>(at::_batch_norm_with_update(input, weight, bias, const_cast<Tensor&>(running_mean),
-// const_cast<Tensor&>(running_var), momentum, eps));
-// }
-// } else {
-// return std::get<0>(at::_batch_norm_no_update(input, weight, bias, running_mean, running_var,
-// momentum, eps));
-// }
+const bool running_stats_defined = running_mean.defined() && running_var.defined();
+if (input.sym_numel() == 0) {
+// don't return view of input, don't return empty tensor because it will break gradient chain
+auto out = input.clone();
+if (weight.defined()) out = out * weight[0];
+if (bias.defined()) out = out + bias[0];
+return out;
+}
+if (!training && !running_stats_defined) {
+AT_ERROR("running_mean and running_var must be defined in evaluation mode");
+}
+if (training && running_stats_defined) {
+BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, training, eps);
+if (backend == BatchNormBackend::Cudnn || backend == BatchNormBackend::Miopen) {
+auto input_c = input;
+if (backend == BatchNormBackend::Cudnn) {
+input_c = input.contiguous(input.suggest_memory_format());
+} else {
+input_c = input.contiguous();
+}
+auto weight_c = weight.contiguous();
+auto bias_c = bias.contiguous();
+auto rmean_c = running_mean.contiguous();
+auto rvar_c = running_var.contiguous();
+return std::get<0>(at::_batch_norm_with_update(input_c, weight_c, bias_c, const_cast<Tensor&>(rmean_c),
+const_cast<Tensor&>(rvar_c), momentum, eps));
+} else {
+return std::get<0>(at::_batch_norm_with_update(input, weight, bias, const_cast<Tensor&>(running_mean),
+const_cast<Tensor&>(running_var), momentum, eps));
+}
+} else {
+return std::get<0>(at::_batch_norm_no_update(input, weight, bias, running_mean, running_var,
+momentum, eps));
+}
}
Tensor instance_norm(
@@ -869,14 +892,21 @@ std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> _batch_norm_with_update_cpu_out(
}
-std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_no_update(
+std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_no_update_cpu(
const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
double momentum, double eps) {
-const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
-const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
+// Users may disable updating running stats during training by passing in None for
+// these stats. In such cases, we go through `_batch_norm_no_update` instead of
+// `_batch_norm_with_update` because the latter's schema expects defined running
+// stats. Therefore, here we support both eval and training paths, using the eval
+// path only if both running stats are defined. Otherwise, passing in undefined
+// Tensors to `batch_norm_cpu` in eval mode would lead to a segfault.
+const bool has_running_mean = running_mean_opt.has_value() && running_mean_opt->defined();
+const bool has_running_var = running_var_opt.has_value() && running_var_opt->defined();
+const bool update = !has_running_mean || !has_running_var;
auto [output, save_mean, save_var] =
-batch_norm_cpu(input, weight_opt, bias_opt, const_cast<Tensor&>(running_mean), const_cast<Tensor&>(running_var), /*update*/false, momentum, eps);
+batch_norm_cpu(input, weight_opt, bias_opt, running_mean_opt, running_var_opt, update, momentum, eps);
Tensor reserve = at::empty({0}, input.options().dtype(kByte));
return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
}
@@ -909,11 +939,11 @@ std::tuple<Tensor&, Tensor&, Tensor&> _batch_norm_legit_no_stats_cpu_out(const T
}
std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_cpu(
-const Tensor& grad_output, const Tensor& input, const Tensor& weight,
+const Tensor& grad_output, const Tensor& input, const std::optional<Tensor>& weight_opt,
const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
const std::optional<Tensor>& save_mean_opt, const std::optional<Tensor>& save_var_opt,
bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
-return batch_norm_backward_cpu(grad_output, input, weight, running_mean_opt, running_var_opt, save_mean_opt, save_var_opt, update, eps, grad_input_mask);
+return batch_norm_backward_cpu(grad_output, input, weight_opt, running_mean_opt, running_var_opt, save_mean_opt, save_var_opt, update, eps, grad_input_mask);
}
std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu(const Tensor& grad_out, const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt, const std::optional<Tensor>& save_mean_opt, const std::optional<Tensor>& save_invstd_opt,
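
Put together, the dispatch that the rewritten `batch_norm` performs can be summarized in a Python sketch (illustrative only; the actual logic is the C++ above):

import torch

def batch_norm_sketch(input, weight, bias, running_mean, running_var,
                      training, momentum, eps):
    has_stats = running_mean is not None and running_var is not None
    if input.numel() == 0:
        # Clone instead of returning a view or an empty tensor so the
        # gradient chain is preserved for empty inputs.
        out = input.clone()
        if weight is not None:
            out = out * weight[0]
        if bias is not None:
            out = out + bias[0]
        return out
    if not training and not has_stats:
        raise RuntimeError(
            "running_mean and running_var must be defined in evaluation mode")
    if training and has_stats:
        # Contiguity for the cuDNN/MIOpen kernels is handled before this
        # call in the C++ implementation.
        return torch.ops.aten._batch_norm_with_update(
            input, weight, bias, running_mean, running_var, momentum, eps)[0]
    # Eval with stats, or training with stats disabled.
    return torch.ops.aten._batch_norm_no_update(
        input, weight, bias, running_mean, running_var, momentum, eps)[0]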

View File

@@ -14,6 +14,6 @@ enum class BatchNormBackend {
Miopen,
};
-TORCH_API BatchNormBackend _select_batch_norm_backend(const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean, const Tensor& running_var, bool training, double eps);
+TORCH_API BatchNormBackend _select_batch_norm_backend(const Tensor& input, const Tensor& weight, const Tensor& bias, const std::optional<Tensor>& running_mean, const std::optional<Tensor>& running_var, bool training, double eps);
} // namespace at::native

View File

@@ -15,6 +15,7 @@
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_batch_norm_with_update_native.h>
+#include <ATen/ops/_batch_norm_no_update_native.h>
#include <ATen/ops/batch_norm_backward_native.h>
#include <ATen/ops/batch_norm_backward_elemt_native.h>
#include <ATen/ops/batch_norm_backward_reduce_native.h>
@@ -481,31 +482,40 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_cuda(const Tensor& self, const std
return std::make_tuple(output, save_mean, save_invstd);
}
-std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_cuda(
+std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_maybe_update_cuda_helper(
const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
-Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
+const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
+double momentum, double eps, bool update) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;
const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
+const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
+const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
Tensor output, save_mean, save_var, reserve;
-BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, /*training*/true, eps);
+BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, update, eps);
if (backend == BatchNormBackend::Cudnn) {
std::tie(output, save_mean, save_var, reserve) =
-at::cudnn_batch_norm(input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
+at::cudnn_batch_norm(input, weight, bias, running_mean, running_var, update, momentum, eps);
} else if (backend == BatchNormBackend::Miopen) {
reserve = at::empty({0}, input.options().dtype(kByte));
std::tie(output, save_mean, save_var) =
-at::miopen_batch_norm(input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
+at::miopen_batch_norm(input, weight, bias, running_mean, running_var, update, momentum, eps);
} else {
reserve = at::empty({0}, input.options().dtype(kByte));
std::tie(output, save_mean, save_var) =
-batch_norm_cuda(input, weight_opt, bias_opt, running_mean, running_var, /*training*/true, momentum, eps);
+batch_norm_cuda(input, weight_opt, bias_opt, running_mean, running_var, update, momentum, eps);
}
return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
}
+std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_cuda(
+const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
+Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
+return _batch_norm_maybe_update_cuda_helper(input, weight_opt, bias_opt, running_mean, running_var, momentum, eps, /*update*/true);
+}
std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> _batch_norm_with_update_cuda_out(
const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
Tensor& running_mean, Tensor& running_var, double momentum, double eps,
@@ -529,6 +539,16 @@ std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> _batch_norm_with_update_cuda_out(
return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>(out, save_mean, save_var, reserve);
}
+std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_no_update_cuda(
+const Tensor& input, const c10::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
+const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
+double momentum, double eps) {
+const bool has_running_mean = running_mean_opt.has_value() && running_mean_opt->defined();
+const bool has_running_var = running_var_opt.has_value() && running_var_opt->defined();
+const bool train = !has_running_mean || !has_running_var;
+return _batch_norm_maybe_update_cuda_helper(input, weight_opt, bias_opt, running_mean_opt, running_var_opt, momentum, eps, train);
+}
std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_cuda(const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt, Tensor& running_mean, Tensor& running_var, bool train, double momentum, double epsilon) {
return batch_norm_cuda(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, epsilon);
}
@@ -546,21 +566,24 @@ std::tuple<Tensor&, Tensor&, Tensor&> _batch_norm_legit_no_stats_cuda_out(const
}
std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_cuda(
-const Tensor& grad_output, const Tensor& input, const Tensor& weight,
+const Tensor& grad_output, const Tensor& input, const std::optional<Tensor>& weight_opt,
const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
const std::optional<Tensor>& save_mean_opt, const std::optional<Tensor>& save_var_opt,
bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
+// See [Note: hacky wrapper removal for optional tensor]
+c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
+const Tensor& weight = *weight_maybe_owned;
const Tensor& dummy_bias = at::empty(1);
const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
const Tensor& save_mean = c10::value_or_else(save_mean_opt, [] {return Tensor();});
const Tensor& save_var = c10::value_or_else(save_var_opt, [] {return Tensor();});
-BatchNormBackend backend = _select_batch_norm_backend(input, weight, dummy_bias, running_mean, running_var, /*training*/true, eps);
+BatchNormBackend backend = _select_batch_norm_backend(input, weight, dummy_bias, running_mean, running_var, update, eps);
-if (backend == BatchNormBackend::Cudnn) {
+if (backend == BatchNormBackend::Cudnn && update) {
return at::cudnn_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var, eps, reserve);
-} else if (backend == BatchNormBackend::Miopen) {
+} else if (backend == BatchNormBackend::Miopen && update) {
return at::miopen_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var, eps);
} else {
return batch_norm_backward_cuda(grad_output, input, weight, running_mean, running_var, save_mean, save_var, update, eps, grad_input_mask);

View File

@@ -7,6 +7,7 @@
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_batch_norm_with_update_native.h>
+#include <ATen/ops/_batch_norm_no_update_native.h>
#include <ATen/ops/batch_norm_backward_native.h>
#include <ATen/ops/_native_batch_norm_legit_native.h>
#include <ATen/ops/_to_dense_native.h>
@@ -68,8 +69,14 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_mkldnn(
TORCH_CHECK(false, "_batch_norm_with_update_mkldnn: ATen not compiled with MKLDNN support");
}
+std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_no_update_mkldnn(
+const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
+const c10::optional<Tensor>& running_mean, const c10::optional<Tensor>& running_var, double momentum, double eps) {
+TORCH_CHECK(false, "_batch_norm_no_update_mkldnn: ATen not compiled with MKLDNN support");
+}
std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_mkldnn(
-const Tensor& grad_output, const Tensor& input, const Tensor& weight,
+const Tensor& grad_output, const Tensor& input, const std::optional<Tensor>& weight_opt,
const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
const std::optional<Tensor>& save_mean_opt, const std::optional<Tensor>& save_var_opt,
bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
@@ -218,6 +225,21 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_mkldnn(
}
+std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_no_update_mkldnn(
+const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
+const c10::optional<Tensor>& running_mean, const c10::optional<Tensor>& running_var,
+double momentum, double eps) {
+const bool has_running_mean = running_mean.has_value() && running_mean->defined();
+const bool has_running_var = running_var.has_value() && running_var->defined();
+const bool train = !has_running_mean || !has_running_var;
+Tensor output, save_mean, save_var;
+std::tie(output, save_mean, save_var) =
+mkldnn_batch_norm(input, weight_opt, bias_opt, running_mean, running_var, train, momentum, eps);
+Tensor reserve = empty_mkldnn({0}, input.scalar_type());
+return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
+}
std::tuple<Tensor, Tensor, Tensor> _mkldnn_batch_norm_legit(
const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt, Tensor& running_mean, Tensor& running_var,
bool train,
@@ -237,11 +259,11 @@ std::tuple<Tensor, Tensor, Tensor> _mkldnn_batch_norm_legit_no_stats(
std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_mkldnn(
-const Tensor& grad_output, const Tensor& input, const Tensor& weight,
+const Tensor& grad_output, const Tensor& input, const std::optional<Tensor>& weight_opt,
const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
const std::optional<Tensor>& save_mean_opt, const std::optional<Tensor>& save_var_opt,
bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
-return mkldnn_batch_norm_backward(grad_output, input, weight, running_mean_opt, running_var_opt, save_mean_opt, save_var_opt, update, eps, grad_input_mask);
+return mkldnn_batch_norm_backward(grad_output, input, weight_opt, running_mean_opt, running_var_opt, save_mean_opt, save_var_opt, update, eps, grad_input_mask);
}

View File

@@ -10,6 +10,7 @@
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
+#include <ATen/ops/_batch_norm_no_update_native.h>
#include <ATen/ops/_batch_norm_with_update_native.h>
#include <ATen/ops/_native_batch_norm_legit_native.h>
#include <ATen/ops/batch_norm_backward_native.h>
@@ -437,6 +438,22 @@ std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> _batch_norm_with_update_mps_out(c
return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>(out, save_mean, save_var, reserve);
}
+std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_no_update_mps(const Tensor& input,
+const std::optional<Tensor>& weight_opt,
+const std::optional<Tensor>& bias_opt,
+const std::optional<Tensor>& running_mean_opt,
+const std::optional<Tensor>& running_var_opt,
+double momentum,
+double eps) {
+const bool has_running_mean = running_mean_opt.has_value() && running_mean_opt->defined();
+const bool has_running_var = running_var_opt.has_value() && running_var_opt->defined();
+const bool train = !has_running_mean || !has_running_var;
+auto [output, save_mean, save_var] =
+batch_norm_mps(input, weight_opt, bias_opt, running_mean_opt, running_var_opt, train, momentum, eps);
+Tensor reserve = at::empty({0}, input.options().dtype(kByte));
+return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
+}
std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_mps(const Tensor& self,
const std::optional<Tensor>& weight_opt,
const std::optional<Tensor>& bias_opt,
@@ -504,7 +521,7 @@ static string get_mem_string(c10::MemoryFormat memory_format) {
// Batch norm backward
std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_mps(const Tensor& grad_output,
const Tensor& input,
-const Tensor& weight,
+const std::optional<Tensor>& weight_opt,
const std::optional<Tensor>& running_mean_opt,
const std::optional<Tensor>& running_var_opt,
const std::optional<Tensor>& save_mean_opt,
@@ -515,7 +532,7 @@ std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_mps(const Tensor& gr
const Tensor& reserve) {
return batch_norm_backward_mps(grad_output,
input,
-weight,
+weight_opt,
running_mean_opt,
running_var_opt,
save_mean_opt,

View File

@@ -6620,10 +6620,13 @@
- func: _batch_norm_no_update(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)
  dispatch:
-    CompositeExplicitAutograd: _batch_norm_no_update
+    CPU: _batch_norm_no_update_cpu
+    CUDA: _batch_norm_no_update_cuda
+    MPS: _batch_norm_no_update_mps
+    MkldnnCPU: _batch_norm_no_update_mkldnn
  autogen: _batch_norm_no_update.out
-- func: batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, bool update, float eps, bool[3] output_mask, Tensor reserve) -> (Tensor, Tensor, Tensor)
+- func: batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, bool update, float eps, bool[3] output_mask, Tensor reserve) -> (Tensor, Tensor, Tensor)
  dispatch:
    CPU: _new_batch_norm_backward_cpu
    CUDA: _new_batch_norm_backward_cuda

View File

@@ -3114,8 +3114,8 @@ def forward(self, x):
bn_running_mean = self.bn.running_mean
bn_running_var = self.bn.running_var
conv2d = torch.ops.aten.conv2d.default(x, conv_weight, conv_bias); x = conv_weight = conv_bias = None
-_native_batch_norm_legit_no_training = torch.ops.aten._native_batch_norm_legit_no_training.default(conv2d, bn_weight, bn_bias, bn_running_mean, bn_running_var, 0.1, 1e-05); conv2d = bn_weight = bn_bias = bn_running_mean = bn_running_var = None
-getitem = _native_batch_norm_legit_no_training[0]; _native_batch_norm_legit_no_training = None
+_batch_norm_no_update = torch.ops.aten._batch_norm_no_update.default(conv2d, bn_weight, bn_bias, bn_running_mean, bn_running_var, 0.1, 1e-05); conv2d = bn_weight = bn_bias = bn_running_mean = bn_running_var = None
+getitem = _batch_norm_no_update[0]; _batch_norm_no_update = None
return pytree.tree_unflatten((getitem,), self._out_spec)""",
)
@@ -3135,12 +3135,12 @@ def forward(self, x):
bn_num_batches_tracked = self.bn.num_batches_tracked
conv2d = torch.ops.aten.conv2d.default(x, conv_weight, conv_bias); x = conv_weight = conv_bias = None
add = torch.ops.aten.add.Tensor(bn_num_batches_tracked, 1)
-_native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(conv2d, bn_weight, bn_bias, bn_running_mean, bn_running_var, True, 0.1, 1e-05); conv2d = bn_weight = bn_bias = None
-getitem = _native_batch_norm_legit_functional[0]
-getitem_3 = _native_batch_norm_legit_functional[3]
-getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
-copy__default = torch.ops.aten.copy_.default(bn_running_mean, getitem_3); bn_running_mean = getitem_3 = None
-copy__default_1 = torch.ops.aten.copy_.default(bn_running_var, getitem_4); bn_running_var = getitem_4 = None
+_batch_norm_with_update_functional = torch.ops.aten._batch_norm_with_update_functional.default(conv2d, bn_weight, bn_bias, bn_running_mean, bn_running_var, 0.1, 1e-05); conv2d = bn_weight = bn_bias = None
+getitem = _batch_norm_with_update_functional[0]
+getitem_4 = _batch_norm_with_update_functional[4]
+getitem_5 = _batch_norm_with_update_functional[5]; _batch_norm_with_update_functional = None
+copy__default = torch.ops.aten.copy_.default(bn_running_mean, getitem_4); bn_running_mean = getitem_4 = None
+copy__default_1 = torch.ops.aten.copy_.default(bn_running_var, getitem_5); bn_running_var = getitem_5 = None
copy__default_2 = torch.ops.aten.copy_.default(bn_num_batches_tracked, add); bn_num_batches_tracked = add = None
return pytree.tree_unflatten((getitem,), self._out_spec)""",
)
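
For reference, `_batch_norm_with_update_functional` returns the updated running stats as extra outputs instead of mutating them in place, which is why the exported graph above reads indices 4 and 5 and copies them back. A sketch of that contract (the six-output layout is inferred from the indices used in the graph):

import torch

x = torch.randn(2, 3, 4, 4)
weight, bias = torch.randn(3), torch.randn(3)
running_mean, running_var = torch.zeros(3), torch.ones(3)

# Assumed layout: (output, save_mean, save_rstd, reserve,
#                  new_running_mean, new_running_var)
outs = torch.ops.aten._batch_norm_with_update_functional(
    x, weight, bias, running_mean, running_var, 0.1, 1e-5)
out = outs[0]
running_mean.copy_(outs[4])  # corresponds to copy__default in the graph
running_var.copy_(outs[5])   # corresponds to copy__default_1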

View File

@@ -258,13 +258,7 @@ class PT2EQATTestCase(QuantizationTestCase):
relu_node = None
getitem_node = output_fq_node.args[0]
bn_node = getitem_node.args[0]
-if is_cuda:
-    if torch.version.cuda is not None:
-        expected_bn_op = torch.ops.aten.cudnn_batch_norm.default
-    elif torch.version.hip is not None:
-        expected_bn_op = torch.ops.aten.miopen_batch_norm.default
-else:
-    expected_bn_op = torch.ops.aten._native_batch_norm_legit.default
+expected_bn_op = torch.ops.aten._batch_norm_with_update.default
self.assertEqual(getitem_node.target, operator.getitem)
self.assertEqual(bn_node.target, expected_bn_op)

View File

@@ -2654,6 +2654,29 @@ class TestFakeTensor(TestCase):
self.assertEqual(strided_result.layout, torch.strided)
+class TestBatchNorm(TestCase):
+    def test_batch_norm_undefined_stats(self):
+        x = torch.randn(1, 3, 2, 1)
+        weight = torch.randn(3)
+        bias = torch.randn(3)
+        rm = torch.randn(3)
+        rv = torch.randn(3)
+        # train=True, passes
+        aten.batch_norm(x, weight, bias, None, None, True, 0.5, 0.6, True)
+        aten.batch_norm(x, weight, bias, rm, None, True, 0.5, 0.6, True)
+        aten.batch_norm(x, weight, bias, None, rv, True, 0.5, 0.6, True)
+        # train=False, expected RuntimeError
+        error_str = "must be defined in evaluation mode"
+        with self.assertRaisesRegex(RuntimeError, error_str):
+            aten.batch_norm(x, weight, bias, None, None, False, 0.5, 0.6, True)
+        with self.assertRaisesRegex(RuntimeError, error_str):
+            aten.batch_norm(x, weight, bias, rm, None, False, 0.5, 0.6, True)
+        with self.assertRaisesRegex(RuntimeError, error_str):
+            aten.batch_norm(x, weight, bias, None, rv, False, 0.5, 0.6, True)
instantiate_device_type_tests(TestCommon, globals())
instantiate_device_type_tests(TestCompositeCompliance, globals())
instantiate_device_type_tests(TestMathBits, globals())

View File

@@ -1271,10 +1271,10 @@
  result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, true, eps)
- name: _batch_norm_no_update(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)
-  input, weight, bias: "grad.defined() ? batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, /*update*/false, eps, grad_input_mask, retain_variables ? result3.clone() : result3) : std::tuple<Tensor, Tensor, Tensor>()"
-  result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, false, eps)
+  input, weight, bias: "grad.defined() ? _batch_norm_no_update_backward_manual(grad, input, weight, running_mean, running_var, result1, result2, eps, grad_input_mask, retain_variables ? result3.clone() : result3) : std::tuple<Tensor, Tensor, Tensor>()"
+  result0: _batch_norm_no_update_jvp_manual(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, eps)
-- name: batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, bool update, float eps, bool[3] output_mask, Tensor reserve) -> (Tensor, Tensor, Tensor)
+- name: batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, bool update, float eps, bool[3] output_mask, Tensor reserve) -> (Tensor, Tensor, Tensor)
  input, weight, grad_out: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, running_mean, running_var, update, eps, save_mean, save_var, grad_input_mask)
  save_mean: not_implemented("batch_norm_backward save_mean")
  save_var: not_implemented("batch_norm_backward save_var")
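
Because the backward now derives the train flag from whether the running stats are defined, autograd should also cover the stats-disabled training path; a quick sanity check along these lines (a sketch, not a test from the PR):

import torch

x = torch.randn(2, 3, 4, dtype=torch.double, requires_grad=True)
w = torch.randn(3, dtype=torch.double, requires_grad=True)
b = torch.randn(3, dtype=torch.double, requires_grad=True)

def fn(x, w, b):
    # training=True with running stats passed as None routes through
    # _batch_norm_no_update, whose backward infers train=True.
    return torch.ops.aten.batch_norm(x, w, b, None, None, True, 0.1, 1e-5, True)

torch.autograd.gradcheck(fn, (x, w, b))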

View File

@@ -2026,8 +2026,8 @@ def _get_batch_norm_reserve_tensor(
input: Tensor,
weight: Optional[Tensor],
bias: Optional[Tensor],
-running_mean: Tensor,
-running_var: Tensor,
+running_mean: Optional[Tensor],
+running_var: Optional[Tensor],
eps: float,
training: bool,
) -> Tensor:
@@ -2108,18 +2108,26 @@ def _batch_norm_no_update(
input: Tensor,
weight: Optional[Tensor],
bias: Optional[Tensor],
-running_mean: Tensor,
-running_var: Tensor,
+running_mean: Optional[Tensor],
+running_var: Optional[Tensor],
momentum: float,
eps: float,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+# Users may disable updating running stats during training
+# by passing in None for these stats. In such cases, we go
+# through `_batch_norm_no_update` instead of `_batch_norm_with_update`
+# because the latter's schema expects defined running stats.
+# Therefore, here we support both eval and training paths,
+# using the eval path only if both running stats are defined.
+training = running_mean is None or running_var is None
output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
input,
weight,
bias,
running_mean,
running_var,
-False,  # training
+training,
momentum,
eps,
False,  # functional
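
With that change, calling the op without running stats normalizes with batch statistics, matching the decomposition's training path; a small check (a sketch written against the op rather than the internal helper):

import torch

x = torch.randn(2, 3, 4)
out, _, _, _ = torch.ops.aten._batch_norm_no_update(
    x, None, None, None, None, 0.1, 1e-5)
# Per-channel batch statistics over dims (N, L), biased variance.
mean = x.mean([0, 2], keepdim=True)
var = x.var([0, 2], unbiased=False, keepdim=True)
torch.testing.assert_close(out, (x - mean) / torch.sqrt(var + 1e-5))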

View File

@@ -198,7 +198,10 @@ def _is_conv_transpose_fn(conv_fn: Callable):
return conv_fn in [F.conv_transpose1d, F.conv_transpose2d]
def _is_bn_node(n: Node):
-    return _is_supported_batch_norm_for_training(n) or n.target == torch.ops.aten._native_batch_norm_legit_no_training.default
+    return n.target in [
+        torch.ops.aten._batch_norm_no_update.default,
+        torch.ops.aten._batch_norm_with_update.default,
+    ]
def fold_bn_weights_into_conv_node(
conv_node: Node,
@@ -220,13 +223,7 @@ def fold_bn_weights_into_conv_node(
bn_b = _get_tensor_constant_from_node(bn_args[2], m)
bn_rm = _get_tensor_constant_from_node(bn_args[3], m)
bn_rv = _get_tensor_constant_from_node(bn_args[4], m)
-if bn_node.target == torch.ops.aten._native_batch_norm_legit_no_training.default:
-    eps_arg_index = 6
-elif _is_supported_batch_norm_for_training(bn_node):
-    eps_arg_index = 7
-else:
-    raise ValueError("BN node target is unexpected ", bn_node.target)
-bn_eps = bn_args[eps_arg_index]
+bn_eps = bn_args[6]
fused_weight, fused_bias = fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=transpose)
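
The simplification works because both consolidated ops share the same leading schema (input, weight, bias, running_mean, running_var, momentum, eps), so a graph pass can read eps from a fixed position; a sketch with a hypothetical helper:

# Hypothetical helper mirroring the fixed argument layout used above.
MOMENTUM_ARG, EPS_ARG = 5, 6

def bn_momentum_eps(bn_node):
    # bn_node: an FX node targeting _batch_norm_no_update or
    # _batch_norm_with_update; both place momentum at 5 and eps at 6.
    return bn_node.args[MOMENTUM_ARG], bn_node.args[EPS_ARG]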

View File

@@ -2273,14 +2273,22 @@ Call this whenever a new thread is created in order to propagate values from
py_module.def(
"_select_batch_norm_backend",
[](const at::Tensor& input,
-const at::Tensor& weight,
-const at::Tensor& bias,
-const at::Tensor& running_mean,
-const at::Tensor& running_var,
+const std::optional<at::Tensor>& weight,
+const std::optional<at::Tensor>& bias,
+const std::optional<at::Tensor>& running_mean,
+const std::optional<at::Tensor>& running_var,
bool training,
double eps) {
+auto weight_opt = weight.has_value() ? weight.value() : at::empty({});
+auto bias_opt = bias.has_value() ? bias.value() : at::empty({});
return at::native::_select_batch_norm_backend(
-input, weight, bias, running_mean, running_var, training, eps);
+input,
+weight_opt,
+bias_opt,
+running_mean,
+running_var,
+training,
+eps);
},
py::arg("input"),
py::arg("weight"),

View File

@@ -7146,6 +7146,74 @@ Tensor values_backward(const Tensor& grad, const Tensor& self) {
return grad_self;
}
+/**
+ * A version of `batch_norm_jvp` used only for `_batch_norm_no_update` that
+ * sets the `train` flag based on whether the running stats are defined,
+ * consistent with the behavior in `_batch_norm_no_update`.
+ */
+Tensor _batch_norm_no_update_jvp_manual(
+const Tensor& input_p,
+const Tensor& input_t,
+const Tensor& weight_p,
+const Tensor& weight_t,
+const Tensor& bias_p,
+const Tensor& bias_t,
+const std::optional<Tensor>& running_mean,
+const std::optional<Tensor>& running_var,
+const Tensor& saved_mean,
+const Tensor& saved_invstd,
+double eps) {
+const bool has_running_mean = running_mean.has_value() && running_mean->defined();
+const bool has_running_var = running_var.has_value() && running_var->defined();
+const bool train = !has_running_mean || !has_running_var;
+return batch_norm_jvp(
+input_p,
+input_t,
+weight_p,
+weight_t,
+bias_p,
+bias_t,
+running_mean,
+running_var,
+saved_mean,
+saved_invstd,
+train,
+eps);
+}
+/**
+ * A version of `batch_norm_backward` used only for `_batch_norm_no_update`
+ * that sets the `train` flag based on whether the running stats are defined,
+ * consistent with the behavior in `_batch_norm_no_update`.
+ */
+std::tuple<Tensor, Tensor, Tensor> _batch_norm_no_update_backward_manual(
+const Tensor& grad_output,
+const Tensor& input,
+const std::optional<Tensor>& weight_opt,
+const std::optional<Tensor>& running_mean_opt,
+const std::optional<Tensor>& running_var_opt,
+const std::optional<Tensor>& save_mean_opt,
+const std::optional<Tensor>& save_var_opt,
+double eps,
+std::array<bool,3> grad_input_mask,
+const Tensor& reserve) {
+const bool has_running_mean = running_mean_opt.has_value() && running_mean_opt->defined();
+const bool has_running_var = running_var_opt.has_value() && running_var_opt->defined();
+const bool train = !has_running_mean || !has_running_var;
+return at::batch_norm_backward(
+grad_output,
+input,
+weight_opt,
+running_mean_opt,
+running_var_opt,
+save_mean_opt,
+save_var_opt,
+train,
+eps,
+grad_input_mask,
+reserve);
+}
} // namespace details
} // namespace generated
} // namespace autograd

View File

@@ -1102,4 +1102,29 @@ mkldnn_rnn_layer_differentiable_backward(
Tensor values_backward(const Tensor& grad, const Tensor& self);
+Tensor _batch_norm_no_update_jvp_manual(
+const Tensor& input_p,
+const Tensor& input_t,
+const Tensor& weight_p,
+const Tensor& weight_t,
+const Tensor& bias_p,
+const Tensor& bias_t,
+const std::optional<Tensor>& running_mean,
+const std::optional<Tensor>& running_var,
+const Tensor& saved_mean,
+const Tensor& saved_invstd,
+double eps);
+std::tuple<Tensor, Tensor, Tensor> _batch_norm_no_update_backward_manual(
+const Tensor& grad_output,
+const Tensor& input,
+const std::optional<Tensor>& weight_opt,
+const std::optional<Tensor>& running_mean_opt,
+const std::optional<Tensor>& running_var_opt,
+const std::optional<Tensor>& save_mean_opt,
+const std::optional<Tensor>& save_var_opt,
+double eps,
+std::array<bool,3> grad_input_mask,
+const Tensor& reserve);
} // namespace torch::autograd::generated::details