mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-27 09:04:53 +08:00

Compare commits: v2.7.1-rc2...switch-bn (1 commit)

| Author | SHA1 | Date |
|---|---|---|
|  | 24e35f0c37 |  |
@@ -861,10 +861,21 @@ static std::tuple<at::Tensor,at::Tensor,at::Tensor> _native_batch_norm_legit_no_
   return at::native_batch_norm(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, eps);
 }
 
+static std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor> _batch_norm_with_update_batch(
+    const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
+    Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
+  at::Tensor output, save_mean, save_var;
+  std::tie(output, save_mean, save_var) =
+      at::native_batch_norm(self, weight_opt, bias_opt, running_mean, running_var, true/*train*/, momentum, eps);
+  at::Tensor reserve = at::empty({0}, self.options().dtype(kByte));
+  return std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>(output, save_mean, save_var, reserve);
+}
+
 TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   VMAP_SUPPORT(native_batch_norm, NATIVE_BATCH_NORM_BATCH_RULE(native_batch_norm));
   VMAP_SUPPORT(cudnn_batch_norm, CUDNN_BATCH_NORM_BATCH_RULE(cudnn_batch_norm));
   VMAP_SUPPORT(miopen_batch_norm, MIOPEN_BATCH_NORM_BATCH_RULE(miopen_batch_norm));
+  m.impl("_batch_norm_with_update", _batch_norm_with_update_batch);
   m.impl("_native_batch_norm_legit", _native_batch_norm_legit_batch);
   m.impl("_native_batch_norm_legit.no_stats", _native_batch_norm_legit_no_stats_batch);
   m.impl("native_batch_norm_backward", NATIVE_BATCH_NORM_BACKWARD_BATCH_RULE(native_batch_norm_backward));
@@ -485,20 +485,31 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu_template(
 }
 
 BatchNormBackend _select_batch_norm_backend(
-    const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean,
-    const Tensor& running_var, bool training, double eps) {
+    const Tensor& input,
+    const Tensor& weight,
+    const Tensor& bias,
+    const std::optional<Tensor>& running_mean_opt,
+    const std::optional<Tensor>& running_var_opt,
+    bool training,
+    double eps) {
 
   auto& ctx = at::globalContext();
   bool cudnn_enabled = ctx.userEnabledCuDNN();
 
+  const bool has_running_mean = running_mean_opt.has_value() && running_mean_opt->defined();
+  const bool has_running_var = running_var_opt.has_value() && running_var_opt->defined();
+  const bool has_valid_running_stats = (
+      (has_running_mean && has_running_var) ||
+      (!has_running_mean && !has_running_var && training)
+  );
+
   if (
       input.is_cuda()
       && input.scalar_type() != at::kBFloat16 && weight.scalar_type() != at::kBFloat16
       && (input.scalar_type() != at::kHalf
        || weight.scalar_type() == at::kFloat)
       && weight.defined() && bias.defined()
-      && ((running_mean.defined() && running_var.defined())
-       || (!running_mean.defined() && !running_var.defined() && training))
+      && has_valid_running_stats
       && (input.dim() >= 3)
       && ((input.sym_size(0) <= 880801 && training) // spatial, training
        ||(input.sym_size(0) <= 65535 && !training)) //spatial, eval
@@ -517,8 +528,7 @@ BatchNormBackend _select_batch_norm_backend(
       && input.scalar_type() != at::kBFloat16
       && (weight.scalar_type() != at::kHalf)
       && weight.defined() && bias.defined()
-      && ((running_mean.defined() && running_var.defined())
-       || (!running_mean.defined() && !running_var.defined() && training))
+      && has_valid_running_stats
      && (input.dim() >= 3)
       && detail::getCUDAHooks().compiledWithMIOpen()
       && cudnn_enabled
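In the new signature the running statistics arrive as optionals, and both backend checks fold the old two-line condition into a single `has_valid_running_stats` flag. A minimal Python sketch of that predicate, for illustration only (not part of the diff):

```python
def has_valid_running_stats(running_mean, running_var, training):
    # mirrors the C++ flag above: either both stats are present,
    # or neither is and we are in training mode
    has_mean = running_mean is not None
    has_var = running_var is not None
    return (has_mean and has_var) or (not has_mean and not has_var and training)

stats = object()  # stands in for a defined running-stat tensor
assert has_valid_running_stats(stats, stats, training=False)
assert has_valid_running_stats(None, None, training=True)
assert not has_valid_running_stats(None, None, training=False)
assert not has_valid_running_stats(stats, None, training=True)  # mismatched stats never qualify
```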
@@ -669,36 +679,49 @@ Tensor batch_norm(
     const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
     const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
     bool training, double momentum, double eps, bool cudnn_enabled) {
-  const Tensor& weight = c10::value_or_else(weight_opt, [] {return Tensor();});
+  // See [Note: hacky wrapper removal for optional tensor]
+  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
+  const Tensor& weight = *weight_maybe_owned;
   const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
   const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
   const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
-  return std::get<0>(at::_batch_norm_impl_index(input, weight, bias, running_mean, running_var,
-              training, momentum, eps, cudnn_enabled));
-  // TODO: switch to the new stack after the 2 week FC window
-  // if (training) {
-  //   BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, training, eps);
-  //   if (backend == BatchNormBackend::Cudnn || backend == BatchNormBackend::Miopen) {
-  //     auto input_c = input;
-  //     if (backend == BatchNormBackend::Cudnn) {
-  //       input_c = input.contiguous(input.suggest_memory_format());
-  //     } else {
-  //       input_c = input.contiguous();
-  //     }
-  //     auto weight_c = weight.contiguous();
-  //     auto bias_c = bias.contiguous();
-  //     auto rmean_c = running_mean.defined() ? running_mean.contiguous() : running_mean;
-  //     auto rvar_c = running_var.defined() ? running_var.contiguous() : running_var;
-  //     return std::get<0>(at::_batch_norm_with_update(input_c, weight_c, bias_c, const_cast<Tensor&>(rmean_c),
-  //                           const_cast<Tensor&>(rvar_c), momentum, eps));
-  //   } else {
-  //     return std::get<0>(at::_batch_norm_with_update(input, weight, bias, const_cast<Tensor&>(running_mean),
-  //                           const_cast<Tensor&>(running_var), momentum, eps));
-  //   }
-  // } else {
-  //   return std::get<0>(at::_batch_norm_no_update(input, weight, bias, running_mean, running_var,
-  //                         momentum, eps));
-  // }
+  const bool running_stats_defined = running_mean.defined() && running_var.defined();
+
+  if (input.sym_numel() == 0) {
+    // don't return view of input, don't return empty tensor because it will break gradient chain
+    auto out = input.clone();
+    if (weight.defined()) out = out * weight[0];
+    if (bias.defined()) out = out + bias[0];
+    return out;
+  }
+
+  if (!training && !running_stats_defined) {
+    AT_ERROR("running_mean and running_var must be defined in evaluation mode");
+  }
+
+  if (training && running_stats_defined) {
+    BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, training, eps);
+    if (backend == BatchNormBackend::Cudnn || backend == BatchNormBackend::Miopen) {
+      auto input_c = input;
+      if (backend == BatchNormBackend::Cudnn) {
+        input_c = input.contiguous(input.suggest_memory_format());
+      } else {
+        input_c = input.contiguous();
+      }
+      auto weight_c = weight.contiguous();
+      auto bias_c = bias.contiguous();
+      auto rmean_c = running_mean.contiguous();
+      auto rvar_c = running_var.contiguous();
+      return std::get<0>(at::_batch_norm_with_update(input_c, weight_c, bias_c, const_cast<Tensor&>(rmean_c),
+                            const_cast<Tensor&>(rvar_c), momentum, eps));
+    } else {
+      return std::get<0>(at::_batch_norm_with_update(input, weight, bias, const_cast<Tensor&>(running_mean),
+                            const_cast<Tensor&>(running_var), momentum, eps));
+    }
+  } else {
+    return std::get<0>(at::_batch_norm_no_update(input, weight, bias, running_mean, running_var,
+                          momentum, eps));
+  }
 }
 
 Tensor instance_norm(
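The rewritten composite now branches explicitly: empty input gets an affine copy, eval without running stats is an error, and everything else routes to either `_batch_norm_with_update` or `_batch_norm_no_update`. A hedged Python sketch of that routing decision (illustration only; the strings name the ops the hunk above calls):

```python
def route_batch_norm(numel, training, running_mean_defined, running_var_defined):
    stats_defined = running_mean_defined and running_var_defined
    if numel == 0:
        return "affine copy of the (empty) input"   # keeps the gradient chain alive
    if not training and not stats_defined:
        raise RuntimeError("running_mean and running_var must be defined in evaluation mode")
    if training and stats_defined:
        return "_batch_norm_with_update"            # updates running stats in place
    return "_batch_norm_no_update"                  # eval, or training without stats

assert route_batch_norm(8, True, True, True) == "_batch_norm_with_update"
assert route_batch_norm(8, False, True, True) == "_batch_norm_no_update"
assert route_batch_norm(8, True, False, False) == "_batch_norm_no_update"
```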
@@ -869,14 +892,21 @@ std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> _batch_norm_with_update_cpu_out(
 }
 
 
-std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_no_update(
+std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_no_update_cpu(
     const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
     const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
     double momentum, double eps) {
-  const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
-  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
+  // Users may disable updating running stats during training by passing in None for
+  // these stats. In such cases, we go through `_batch_norm_no_update` instead of
+  // `_batch_norm_with_update` because the latter's schema expects defined running
+  // stats. Therefore, here we support both eval and training paths, using the eval
+  // path only if both running stats are defined. Otherwise, passing in undefined
+  // Tensors to `batch_norm_cpu` in eval mode would lead to seg fault.
+  const bool has_running_mean = running_mean_opt.has_value() && running_mean_opt->defined();
+  const bool has_running_var = running_var_opt.has_value() && running_var_opt->defined();
+  const bool update = !has_running_mean || !has_running_var;
   auto [output, save_mean, save_var] =
-    batch_norm_cpu(input, weight_opt, bias_opt, const_cast<Tensor&>(running_mean), const_cast<Tensor&>(running_var), /*update*/false, momentum, eps);
+    batch_norm_cpu(input, weight_opt, bias_opt, running_mean_opt, running_var_opt, update, momentum, eps);
   Tensor reserve = at::empty({0}, input.options().dtype(kByte));
   return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
 }
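The comment above is the crux of the change: `_batch_norm_no_update` now serves both evaluation and the "training without running stats" configuration, deriving its behaviour from whether the optional stats are defined. A hypothetical user-level sketch of how a model ends up on that path, assuming this branch's dispatch (not a released-API guarantee):

```python
import torch
import torch.nn as nn

# track_running_stats=False leaves running_mean/running_var as None, which is
# exactly the case the comment above describes
bn = nn.BatchNorm2d(3, track_running_stats=False)
bn.train()

x = torch.randn(2, 3, 4, 4)
y = bn(x)          # normalized with batch statistics; there are no stats to update
print(y.shape)     # torch.Size([2, 3, 4, 4])
```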
@@ -909,11 +939,11 @@ std::tuple<Tensor&, Tensor&, Tensor&> _batch_norm_legit_no_stats_cpu_out(const T
 }
 
 std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_cpu(
-    const Tensor& grad_output, const Tensor& input, const Tensor& weight,
+    const Tensor& grad_output, const Tensor& input, const std::optional<Tensor>& weight_opt,
     const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
     const std::optional<Tensor>& save_mean_opt, const std::optional<Tensor>& save_var_opt,
     bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
-  return batch_norm_backward_cpu(grad_output, input, weight, running_mean_opt, running_var_opt, save_mean_opt, save_var_opt, update, eps, grad_input_mask);
+  return batch_norm_backward_cpu(grad_output, input, weight_opt, running_mean_opt, running_var_opt, save_mean_opt, save_var_opt, update, eps, grad_input_mask);
 }
 
 std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu(const Tensor& grad_out, const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt, const std::optional<Tensor>& save_mean_opt, const std::optional<Tensor>& save_invstd_opt,

@@ -14,6 +14,6 @@ enum class BatchNormBackend {
   Miopen,
 };
 
-TORCH_API BatchNormBackend _select_batch_norm_backend(const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean, const Tensor& running_var, bool training, double eps);
+TORCH_API BatchNormBackend _select_batch_norm_backend(const Tensor& input, const Tensor& weight, const Tensor& bias, const std::optional<Tensor>& running_mean, const std::optional<Tensor>& running_var, bool training, double eps);
 
 } // namespace at::native
@@ -15,6 +15,7 @@
 #include <ATen/NativeFunctions.h>
 #else
 #include <ATen/ops/_batch_norm_with_update_native.h>
+#include <ATen/ops/_batch_norm_no_update_native.h>
 #include <ATen/ops/batch_norm_backward_native.h>
 #include <ATen/ops/batch_norm_backward_elemt_native.h>
 #include <ATen/ops/batch_norm_backward_reduce_native.h>

@@ -481,31 +482,40 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_cuda(const Tensor& self, const std
   return std::make_tuple(output, save_mean, save_invstd);
 }
 
-std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_cuda(
+std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_maybe_update_cuda_helper(
     const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
-    Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
+    const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
+    double momentum, double eps, bool update) {
   // See [Note: hacky wrapper removal for optional tensor]
   c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
   const Tensor& weight = *weight_maybe_owned;
   const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
+  const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
+  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
   Tensor output, save_mean, save_var, reserve;
 
-  BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, /*training*/true, eps);
+  BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, update, eps);
   if (backend == BatchNormBackend::Cudnn) {
     std::tie(output, save_mean, save_var, reserve) =
-      at::cudnn_batch_norm(input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
+      at::cudnn_batch_norm(input, weight, bias, running_mean, running_var, update, momentum, eps);
   } else if (backend == BatchNormBackend::Miopen) {
     reserve = at::empty({0}, input.options().dtype(kByte));
     std::tie(output, save_mean, save_var) =
-      at::miopen_batch_norm(input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
+      at::miopen_batch_norm(input, weight, bias, running_mean, running_var, update, momentum, eps);
   } else {
     reserve = at::empty({0}, input.options().dtype(kByte));
     std::tie(output, save_mean, save_var) =
-      batch_norm_cuda(input, weight_opt, bias_opt, running_mean, running_var, /*training*/true, momentum, eps);
+      batch_norm_cuda(input, weight_opt, bias_opt, running_mean, running_var, update, momentum, eps);
   }
   return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
 }
 
+std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_cuda(
+    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
+    Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
+  return _batch_norm_maybe_update_cuda_helper(input, weight_opt, bias_opt, running_mean, running_var, momentum, eps, /*update*/true);
+}
+
 std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> _batch_norm_with_update_cuda_out(
     const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
     Tensor& running_mean, Tensor& running_var, double momentum, double eps,

@@ -529,6 +539,16 @@ std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> _batch_norm_with_update_cuda_out(
   return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>(out, save_mean, save_var, reserve);
 }
 
+std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_no_update_cuda(
+    const Tensor& input, const c10::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
+    const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
+    double momentum, double eps) {
+  const bool has_running_mean = running_mean_opt.has_value() && running_mean_opt->defined();
+  const bool has_running_var = running_var_opt.has_value() && running_var_opt->defined();
+  const bool train = !has_running_mean || !has_running_var;
+  return _batch_norm_maybe_update_cuda_helper(input, weight_opt, bias_opt, running_mean_opt, running_var_opt, momentum, eps, train);
+}
+
 std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_cuda(const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt, Tensor& running_mean, Tensor& running_var, bool train, double momentum, double epsilon) {
   return batch_norm_cuda(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, epsilon);
 }

@@ -546,21 +566,24 @@ std::tuple<Tensor&, Tensor&, Tensor&> _batch_norm_legit_no_stats_cuda_out(const
 }
 
 std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_cuda(
-    const Tensor& grad_output, const Tensor& input, const Tensor& weight,
+    const Tensor& grad_output, const Tensor& input, const std::optional<Tensor>& weight_opt,
     const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
    const std::optional<Tensor>& save_mean_opt, const std::optional<Tensor>& save_var_opt,
     bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
+  // See [Note: hacky wrapper removal for optional tensor]
+  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
+  const Tensor& weight = *weight_maybe_owned;
   const Tensor& dummy_bias = at::empty(1);
   const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
   const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
   const Tensor& save_mean = c10::value_or_else(save_mean_opt, [] {return Tensor();});
   const Tensor& save_var = c10::value_or_else(save_var_opt, [] {return Tensor();});
 
-  BatchNormBackend backend = _select_batch_norm_backend(input, weight, dummy_bias, running_mean, running_var, /*training*/true, eps);
+  BatchNormBackend backend = _select_batch_norm_backend(input, weight, dummy_bias, running_mean, running_var, update, eps);
 
-  if (backend == BatchNormBackend::Cudnn) {
+  if (backend == BatchNormBackend::Cudnn && update) {
     return at::cudnn_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var, eps, reserve);
-  } else if (backend == BatchNormBackend::Miopen) {
+  } else if (backend == BatchNormBackend::Miopen && update) {
     return at::miopen_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var, eps);
   } else {
     return batch_norm_backward_cuda(grad_output, input, weight, running_mean, running_var, save_mean, save_var, update, eps, grad_input_mask);
@@ -7,6 +7,7 @@
 #include <ATen/NativeFunctions.h>
 #else
 #include <ATen/ops/_batch_norm_with_update_native.h>
+#include <ATen/ops/_batch_norm_no_update_native.h>
 #include <ATen/ops/batch_norm_backward_native.h>
 #include <ATen/ops/_native_batch_norm_legit_native.h>
 #include <ATen/ops/_to_dense_native.h>

@@ -68,8 +69,14 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_mkldnn(
   TORCH_CHECK(false, "_batch_norm_with_update_mkldnn: ATen not compiled with MKLDNN support");
 }
 
+std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_no_update_mkldnn(
+    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
+    const c10::optional<Tensor>& running_mean, const c10::optional<Tensor>& running_var, double momentum, double eps) {
+  TORCH_CHECK(false, "_batch_norm_no_update_mkldnn: ATen not compiled with MKLDNN support");
+}
+
 std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_mkldnn(
-    const Tensor& grad_output, const Tensor& input, const Tensor& weight,
+    const Tensor& grad_output, const Tensor& input, const std::optional<Tensor>& weight_opt,
     const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
     const std::optional<Tensor>& save_mean_opt, const std::optional<Tensor>& save_var_opt,
     bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {

@@ -218,6 +225,21 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_mkldnn(
 }
 
 
+std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_no_update_mkldnn(
+    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
+    const c10::optional<Tensor>& running_mean, const c10::optional<Tensor>& running_var,
+    double momentum, double eps) {
+  const bool has_running_mean = running_mean.has_value() && running_mean->defined();
+  const bool has_running_var = running_var.has_value() && running_var->defined();
+  const bool train = !has_running_mean || !has_running_var;
+  Tensor output, save_mean, save_var;
+  std::tie(output, save_mean, save_var) =
+      mkldnn_batch_norm(input, weight_opt, bias_opt, running_mean, running_var, train, momentum, eps);
+  Tensor reserve = empty_mkldnn({0}, input.scalar_type());
+  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
+}
+
+
 std::tuple<Tensor, Tensor, Tensor> _mkldnn_batch_norm_legit(
     const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt, Tensor& running_mean, Tensor& running_var,
     bool train,

@@ -237,11 +259,11 @@ std::tuple<Tensor, Tensor, Tensor> _mkldnn_batch_norm_legit_no_stats(
 
 
 std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_mkldnn(
-    const Tensor& grad_output, const Tensor& input, const Tensor& weight,
+    const Tensor& grad_output, const Tensor& input, const std::optional<Tensor>& weight_opt,
     const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
     const std::optional<Tensor>& save_mean_opt, const std::optional<Tensor>& save_var_opt,
     bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
-  return mkldnn_batch_norm_backward(grad_output, input, weight, running_mean_opt, running_var_opt, save_mean_opt, save_var_opt, update, eps, grad_input_mask);
+  return mkldnn_batch_norm_backward(grad_output, input, weight_opt, running_mean_opt, running_var_opt, save_mean_opt, save_var_opt, update, eps, grad_input_mask);
 }
 
@@ -10,6 +10,7 @@
 #include <ATen/Functions.h>
 #include <ATen/NativeFunctions.h>
 #else
+#include <ATen/ops/_batch_norm_no_update_native.h>
 #include <ATen/ops/_batch_norm_with_update_native.h>
 #include <ATen/ops/_native_batch_norm_legit_native.h>
 #include <ATen/ops/batch_norm_backward_native.h>

@@ -437,6 +438,22 @@ std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> _batch_norm_with_update_mps_out(c
   return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>(out, save_mean, save_var, reserve);
 }
 
+std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_no_update_mps(const Tensor& input,
+                                                                     const std::optional<Tensor>& weight_opt,
+                                                                     const std::optional<Tensor>& bias_opt,
+                                                                     const std::optional<Tensor>& running_mean_opt,
+                                                                     const std::optional<Tensor>& running_var_opt,
+                                                                     double momentum,
+                                                                     double eps) {
+  const bool has_running_mean = running_mean_opt.has_value() && running_mean_opt->defined();
+  const bool has_running_var = running_var_opt.has_value() && running_var_opt->defined();
+  const bool train = !has_running_mean || !has_running_var;
+  auto [output, save_mean, save_var] =
+      batch_norm_mps(input, weight_opt, bias_opt, running_mean_opt, running_var_opt, train, momentum, eps);
+  Tensor reserve = at::empty({0}, input.options().dtype(kByte));
+  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
+}
+
 std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_mps(const Tensor& self,
                                                          const std::optional<Tensor>& weight_opt,
                                                          const std::optional<Tensor>& bias_opt,

@@ -504,7 +521,7 @@ static string get_mem_string(c10::MemoryFormat memory_format) {
 // Batch norm backward
 std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_mps(const Tensor& grad_output,
                                                                 const Tensor& input,
-                                                                const Tensor& weight,
+                                                                const std::optional<Tensor>& weight_opt,
                                                                 const std::optional<Tensor>& running_mean_opt,
                                                                 const std::optional<Tensor>& running_var_opt,
                                                                 const std::optional<Tensor>& save_mean_opt,

@@ -515,7 +532,7 @@ std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_mps(const Tensor& gr
                                                                 const Tensor& reserve) {
   return batch_norm_backward_mps(grad_output,
                                  input,
-                                 weight,
+                                 weight_opt,
                                  running_mean_opt,
                                  running_var_opt,
                                  save_mean_opt,
@@ -6620,10 +6620,13 @@
 
 - func: _batch_norm_no_update(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)
   dispatch:
-    CompositeExplicitAutograd: _batch_norm_no_update
+    CPU: _batch_norm_no_update_cpu
+    CUDA: _batch_norm_no_update_cuda
+    MPS: _batch_norm_no_update_mps
+    MkldnnCPU: _batch_norm_no_update_mkldnn
   autogen: _batch_norm_no_update.out
 
-- func: batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, bool update, float eps, bool[3] output_mask, Tensor reserve) -> (Tensor, Tensor, Tensor)
+- func: batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, bool update, float eps, bool[3] output_mask, Tensor reserve) -> (Tensor, Tensor, Tensor)
   dispatch:
     CPU: _new_batch_norm_backward_cpu
     CUDA: _new_batch_norm_backward_cuda
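With the yaml entry above, `_batch_norm_no_update` stops being a single CompositeExplicitAutograd kernel and gains the per-backend kernels defined earlier in the diff. A hedged usage sketch of calling the op straight through the dispatcher; the positional layout follows the schema shown above, and which kernel runs depends on the device and build:

```python
import torch

x = torch.randn(2, 3, 4, 4)
weight, bias = torch.ones(3), torch.zeros(3)
running_mean, running_var = torch.zeros(3), torch.ones(3)

# arguments follow the schema above: input, weight, bias, stats, momentum, eps
out, save_mean, save_var, reserve = torch.ops.aten._batch_norm_no_update(
    x, weight, bias, running_mean, running_var, 0.1, 1e-5)
print(out.shape)   # same shape as the input; the auxiliary outputs are backend-dependent
```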
@@ -3114,8 +3114,8 @@ def forward(self, x):
     bn_running_mean = self.bn.running_mean
     bn_running_var = self.bn.running_var
     conv2d = torch.ops.aten.conv2d.default(x, conv_weight, conv_bias); x = conv_weight = conv_bias = None
-    _native_batch_norm_legit_no_training = torch.ops.aten._native_batch_norm_legit_no_training.default(conv2d, bn_weight, bn_bias, bn_running_mean, bn_running_var, 0.1, 1e-05); conv2d = bn_weight = bn_bias = bn_running_mean = bn_running_var = None
-    getitem = _native_batch_norm_legit_no_training[0]; _native_batch_norm_legit_no_training = None
+    _batch_norm_no_update = torch.ops.aten._batch_norm_no_update.default(conv2d, bn_weight, bn_bias, bn_running_mean, bn_running_var, 0.1, 1e-05); conv2d = bn_weight = bn_bias = bn_running_mean = bn_running_var = None
+    getitem = _batch_norm_no_update[0]; _batch_norm_no_update = None
     return pytree.tree_unflatten((getitem,), self._out_spec)""",
 )
 

@@ -3135,12 +3135,12 @@ def forward(self, x):
     bn_num_batches_tracked = self.bn.num_batches_tracked
     conv2d = torch.ops.aten.conv2d.default(x, conv_weight, conv_bias); x = conv_weight = conv_bias = None
     add = torch.ops.aten.add.Tensor(bn_num_batches_tracked, 1)
-    _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(conv2d, bn_weight, bn_bias, bn_running_mean, bn_running_var, True, 0.1, 1e-05); conv2d = bn_weight = bn_bias = None
-    getitem = _native_batch_norm_legit_functional[0]
-    getitem_3 = _native_batch_norm_legit_functional[3]
-    getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
-    copy__default = torch.ops.aten.copy_.default(bn_running_mean, getitem_3); bn_running_mean = getitem_3 = None
-    copy__default_1 = torch.ops.aten.copy_.default(bn_running_var, getitem_4); bn_running_var = getitem_4 = None
+    _batch_norm_with_update_functional = torch.ops.aten._batch_norm_with_update_functional.default(conv2d, bn_weight, bn_bias, bn_running_mean, bn_running_var, 0.1, 1e-05); conv2d = bn_weight = bn_bias = None
+    getitem = _batch_norm_with_update_functional[0]
+    getitem_4 = _batch_norm_with_update_functional[4]
+    getitem_5 = _batch_norm_with_update_functional[5]; _batch_norm_with_update_functional = None
+    copy__default = torch.ops.aten.copy_.default(bn_running_mean, getitem_4); bn_running_mean = getitem_4 = None
+    copy__default_1 = torch.ops.aten.copy_.default(bn_running_var, getitem_5); bn_running_var = getitem_5 = None
     copy__default_2 = torch.ops.aten.copy_.default(bn_num_batches_tracked, add); bn_num_batches_tracked = add = None
     return pytree.tree_unflatten((getitem,), self._out_spec)""",
 )
@@ -258,13 +258,7 @@ class PT2EQATTestCase(QuantizationTestCase):
         relu_node = None
         getitem_node = output_fq_node.args[0]
         bn_node = getitem_node.args[0]
-        if is_cuda:
-            if torch.version.cuda is not None:
-                expected_bn_op = torch.ops.aten.cudnn_batch_norm.default
-            elif torch.version.hip is not None:
-                expected_bn_op = torch.ops.aten.miopen_batch_norm.default
-        else:
-            expected_bn_op = torch.ops.aten._native_batch_norm_legit.default
+        expected_bn_op = torch.ops.aten._batch_norm_with_update.default
         self.assertEqual(getitem_node.target, operator.getitem)
         self.assertEqual(bn_node.target, expected_bn_op)
 
@@ -2654,6 +2654,29 @@ class TestFakeTensor(TestCase):
         self.assertEqual(strided_result.layout, torch.strided)
 
 
+class TestBatchNorm(TestCase):
+    def test_batch_norm_undefined_stats(self):
+        x = torch.randn(1, 3, 2, 1)
+        weight = torch.randn(3)
+        bias = torch.randn(3)
+        rm = torch.randn(3)
+        rv = torch.randn(3)
+
+        # train=True, passes
+        aten.batch_norm(x, weight, bias, None, None, True, 0.5, 0.6, True)
+        aten.batch_norm(x, weight, bias, rm, None, True, 0.5, 0.6, True)
+        aten.batch_norm(x, weight, bias, None, rv, True, 0.5, 0.6, True)
+
+        # train=False, expected RuntimeError
+        error_str = "must be defined in evaluation mode"
+        with self.assertRaisesRegex(RuntimeError, error_str):
+            aten.batch_norm(x, weight, bias, None, None, False, 0.5, 0.6, True)
+        with self.assertRaisesRegex(RuntimeError, error_str):
+            aten.batch_norm(x, weight, bias, rm, None, False, 0.5, 0.6, True)
+        with self.assertRaisesRegex(RuntimeError, error_str):
+            aten.batch_norm(x, weight, bias, None, rv, False, 0.5, 0.6, True)
+
+
 instantiate_device_type_tests(TestCommon, globals())
 instantiate_device_type_tests(TestCompositeCompliance, globals())
 instantiate_device_type_tests(TestMathBits, globals())
@@ -1271,10 +1271,10 @@
   result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, true, eps)
 
 - name: _batch_norm_no_update(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)
-  input, weight, bias: "grad.defined() ? batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, /*update*/false, eps, grad_input_mask, retain_variables ? result3.clone() : result3) : std::tuple<Tensor, Tensor, Tensor>()"
-  result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, false, eps)
+  input, weight, bias: "grad.defined() ? _batch_norm_no_update_backward_manual(grad, input, weight, running_mean, running_var, result1, result2, eps, grad_input_mask, retain_variables ? result3.clone() : result3) : std::tuple<Tensor, Tensor, Tensor>()"
+  result0: _batch_norm_no_update_jvp_manual(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, eps)
 
-- name: batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, bool update, float eps, bool[3] output_mask, Tensor reserve) -> (Tensor, Tensor, Tensor)
+- name: batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, bool update, float eps, bool[3] output_mask, Tensor reserve) -> (Tensor, Tensor, Tensor)
   input, weight, grad_out: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, running_mean, running_var, update, eps, save_mean, save_var, grad_input_mask)
   save_mean: not_implemented("batch_norm_backward save_mean")
   save_var: not_implemented("batch_norm_backward save_var")
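The derivatives entry above routes the backward and JVP of `_batch_norm_no_update` through the new manual helpers, which pick the train flag from whether the running stats are defined. A hedged sketch of exercising that formula from Python by differentiating through batch norm with the stats omitted (assuming this branch routes such calls to `_batch_norm_no_update`):

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 4, 4, dtype=torch.double, requires_grad=True)
w = torch.randn(3, dtype=torch.double, requires_grad=True)
b = torch.randn(3, dtype=torch.double, requires_grad=True)

# training=True with running_mean/running_var omitted is the configuration
# the new autograd entry above is meant to cover
torch.autograd.gradcheck(
    lambda x, w, b: F.batch_norm(x, None, None, w, b, training=True, eps=1e-5),
    (x, w, b),
)
```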
@@ -2026,8 +2026,8 @@ def _get_batch_norm_reserve_tensor(
     input: Tensor,
     weight: Optional[Tensor],
     bias: Optional[Tensor],
-    running_mean: Tensor,
-    running_var: Tensor,
+    running_mean: Optional[Tensor],
+    running_var: Optional[Tensor],
     eps: float,
     training: bool,
 ) -> Tensor:

@@ -2108,18 +2108,26 @@ def _batch_norm_no_update(
     input: Tensor,
     weight: Optional[Tensor],
     bias: Optional[Tensor],
-    running_mean: Tensor,
-    running_var: Tensor,
+    running_mean: Optional[Tensor],
+    running_var: Optional[Tensor],
     momentum: float,
     eps: float,
 ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+    # Users may disable updating running stats during training
+    # by passing in None for these stats. In such cases, we go
+    # through `_batch_norm_no_update` instead of `_batch_norm_with_update`
+    # because the latter's schema expects defined running stats.
+    # Therefore, here we support both eval and training paths,
+    # using the eval path only if both running stats are defined.
+    training = running_mean is None or running_var is None
+
     output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
         input,
         weight,
         bias,
         running_mean,
         running_var,
-        False, # training
+        training,
         momentum,
         eps,
         False, # functional
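The decomposition now switches between normalizing with batch statistics (when a stat is None) and normalizing with the supplied running statistics. A small numeric sketch of those two normalizations, independent of the diff:

```python
import torch

x = torch.randn(4, 3, 8, 8)
running_mean, running_var, eps = torch.zeros(3), torch.ones(3), 1e-5

# training path: per-channel batch statistics (biased variance)
batch_mean = x.mean(dim=(0, 2, 3))
batch_var = x.var(dim=(0, 2, 3), unbiased=False)
train_out = (x - batch_mean[None, :, None, None]) / torch.sqrt(batch_var[None, :, None, None] + eps)

# eval path: the supplied running statistics
eval_out = (x - running_mean[None, :, None, None]) / torch.sqrt(running_var[None, :, None, None] + eps)

print(train_out.mean(dim=(0, 2, 3)))  # ~0 per channel when batch statistics are used
```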
@@ -198,7 +198,10 @@ def _is_conv_transpose_fn(conv_fn: Callable):
     return conv_fn in [F.conv_transpose1d, F.conv_transpose2d]
 
 def _is_bn_node(n: Node):
-    return _is_supported_batch_norm_for_training(n) or n.target == torch.ops.aten._native_batch_norm_legit_no_training.default
+    return n.target in [
+        torch.ops.aten._batch_norm_no_update.default,
+        torch.ops.aten._batch_norm_with_update.default,
+    ]
 
 def fold_bn_weights_into_conv_node(
     conv_node: Node,

@@ -220,13 +223,7 @@ def fold_bn_weights_into_conv_node(
     bn_b = _get_tensor_constant_from_node(bn_args[2], m)
     bn_rm = _get_tensor_constant_from_node(bn_args[3], m)
     bn_rv = _get_tensor_constant_from_node(bn_args[4], m)
-    if bn_node.target == torch.ops.aten._native_batch_norm_legit_no_training.default:
-        eps_arg_index = 6
-    elif _is_supported_batch_norm_for_training(bn_node):
-        eps_arg_index = 7
-    else:
-        raise ValueError("BN node target is unexpected ", bn_node.target)
-    bn_eps = bn_args[eps_arg_index]
+    bn_eps = bn_args[6]
 
     fused_weight, fused_bias = fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=transpose)
 
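The eps lookup collapses to a single index because `_batch_norm_no_update` and `_batch_norm_with_update` share the same leading argument layout (the no_update schema appears in the native_functions hunk above; the same layout for with_update is assumed here). A one-line sanity sketch:

```python
# shared positional layout assumed for both new ops: eps sits at index 6
BN_ARGS = ("input", "weight", "bias", "running_mean", "running_var", "momentum", "eps")
assert BN_ARGS.index("eps") == 6
```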
@@ -2273,14 +2273,22 @@ Call this whenever a new thread is created in order to propagate values from
   py_module.def(
       "_select_batch_norm_backend",
       [](const at::Tensor& input,
-         const at::Tensor& weight,
-         const at::Tensor& bias,
-         const at::Tensor& running_mean,
-         const at::Tensor& running_var,
+         const std::optional<at::Tensor>& weight,
+         const std::optional<at::Tensor>& bias,
+         const std::optional<at::Tensor>& running_mean,
+         const std::optional<at::Tensor>& running_var,
          bool training,
          double eps) {
+        auto weight_opt = weight.has_value() ? weight.value() : at::empty({});
+        auto bias_opt = bias.has_value() ? bias.value() : at::empty({});
         return at::native::_select_batch_norm_backend(
-            input, weight, bias, running_mean, running_var, training, eps);
+            input,
+            weight_opt,
+            bias_opt,
+            running_mean,
+            running_var,
+            training,
+            eps);
       },
       py::arg("input"),
       py::arg("weight"),
@@ -7146,6 +7146,74 @@ Tensor values_backward(const Tensor& grad, const Tensor& self) {
   return grad_self;
 }
 
+/**
+ * A version of `batch_norm_jvp` used only for `_batch_norm_no_update` that
+ * sets the `train` flag based on whether the running stats are defined,
+ * consistent with the behavior in `_batch_norm_no_update`.
+ */
+Tensor _batch_norm_no_update_jvp_manual(
+    const Tensor& input_p,
+    const Tensor& input_t,
+    const Tensor& weight_p,
+    const Tensor& weight_t,
+    const Tensor& bias_p,
+    const Tensor& bias_t,
+    const std::optional<Tensor>& running_mean,
+    const std::optional<Tensor>& running_var,
+    const Tensor& saved_mean,
+    const Tensor& saved_invstd,
+    double eps) {
+  const bool has_running_mean = running_mean.has_value() && running_mean->defined();
+  const bool has_running_var = running_var.has_value() && running_var->defined();
+  const bool train = !has_running_mean || !has_running_var;
+  return batch_norm_jvp(
+      input_p,
+      input_t,
+      weight_p,
+      weight_t,
+      bias_p,
+      bias_t,
+      running_mean,
+      running_var,
+      saved_mean,
+      saved_invstd,
+      train,
+      eps);
+}
+
+/**
+ * A version of `batch_norm_backward` used only for `_batch_norm_no_update`
+ * that sets the `train` flag based on whether the running stats are defined,
+ * consistent with the behavior in `_batch_norm_no_update`.
+ */
+std::tuple<Tensor, Tensor, Tensor> _batch_norm_no_update_backward_manual(
+    const Tensor& grad_output,
+    const Tensor& input,
+    const std::optional<Tensor>& weight_opt,
+    const std::optional<Tensor>& running_mean_opt,
+    const std::optional<Tensor>& running_var_opt,
+    const std::optional<Tensor>& save_mean_opt,
+    const std::optional<Tensor>& save_var_opt,
+    double eps,
+    std::array<bool,3> grad_input_mask,
+    const Tensor& reserve) {
+  const bool has_running_mean = running_mean_opt.has_value() && running_mean_opt->defined();
+  const bool has_running_var = running_var_opt.has_value() && running_var_opt->defined();
+  const bool train = !has_running_mean || !has_running_var;
+  return at::batch_norm_backward(
+      grad_output,
+      input,
+      weight_opt,
+      running_mean_opt,
+      running_var_opt,
+      save_mean_opt,
+      save_var_opt,
+      train,
+      eps,
+      grad_input_mask,
+      reserve);
+}
+
 } // namespace details
 } // namespace generated
 } // namespace autograd
@@ -1102,4 +1102,29 @@ mkldnn_rnn_layer_differentiable_backward(
 
 Tensor values_backward(const Tensor& grad, const Tensor& self);
 
+Tensor _batch_norm_no_update_jvp_manual(
+    const Tensor& input_p,
+    const Tensor& input_t,
+    const Tensor& weight_p,
+    const Tensor& weight_t,
+    const Tensor& bias_p,
+    const Tensor& bias_t,
+    const std::optional<Tensor>& running_mean,
+    const std::optional<Tensor>& running_var,
+    const Tensor& saved_mean,
+    const Tensor& saved_invstd,
+    double eps);
+
+std::tuple<Tensor, Tensor, Tensor> _batch_norm_no_update_backward_manual(
+    const Tensor& grad_output,
+    const Tensor& input,
+    const std::optional<Tensor>& weight_opt,
+    const std::optional<Tensor>& running_mean_opt,
+    const std::optional<Tensor>& running_var_opt,
+    const std::optional<Tensor>& save_mean_opt,
+    const std::optional<Tensor>& save_var_opt,
+    double eps,
+    std::array<bool,3> grad_input_mask,
+    const Tensor& reserve);
+
 } // namespace torch::autograd::generated::details