Revert "Batch Norm Consolidation (#116092)"
This reverts commit 7b4f70eda519ccd7f28de17689edd43c52743bc9. Reverted https://github.com/pytorch/pytorch/pull/116092 on behalf of https://github.com/osalpekar due to Causes build failure in //caffe2:aten-hip (AMD build) target. See [D54707318](https://www.internalfb.com/diff/D54707318) for more details, may require internal build system changes to resolve. ([comment](https://github.com/pytorch/pytorch/pull/116092#issuecomment-1989542965))
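For orientation: the change being reverted (#116092) consolidated the batch-norm entry points behind a small set of new ATen ops — `_batch_norm_with_update`, `_batch_norm_no_update`, and a unified `batch_norm_backward` — plus a shared `_select_batch_norm_backend` helper returning a `BatchNormBackend` enum (`Native`, `Cudnn`, `Miopen`), so cuDNN/MIOpen eligibility was decided in one place and every forward returned a `reserve` tensor for the backward. The hunks below undo all of that. A minimal sketch of the consolidated forward, condensed from the deleted `_batch_norm_with_update_cuda` kernel further down (it assumes the internals this diff removes and is illustrative, not a buildable drop-in):

```cpp
#include <ATen/ATen.h>

// Sketch of the reverted pattern, assuming the removed internals
// (at::native::BatchNormBackend, at::native::_select_batch_norm_backend).
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> with_update_sketch(
    const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias,
    at::Tensor& running_mean, at::Tensor& running_var,
    double momentum, double eps) {
  at::Tensor output, save_mean, save_var, reserve;
  // One backend decision up front, shared by every call site.
  auto backend = at::native::_select_batch_norm_backend(
      input, weight, bias, running_mean, running_var, /*training*/true, eps);
  if (backend == at::native::BatchNormBackend::Cudnn) {
    // Only cuDNN produces a real reserve buffer; the unified backward reuses it.
    std::tie(output, save_mean, save_var, reserve) = at::cudnn_batch_norm(
        input, weight, bias, running_mean, running_var,
        /*training*/true, momentum, eps);
  } else {
    // MIOpen/native paths return an empty reserve so the four-output schema
    // stays uniform across backends.
    reserve = at::empty({0}, input.options().dtype(at::kByte));
    std::tie(output, save_mean, save_var) = at::native_batch_norm(
        input, weight, bias, running_mean, running_var,
        /*training*/true, momentum, eps);
  }
  return std::make_tuple(output, save_mean, save_var, reserve);
}
```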
@@ -29,11 +29,6 @@
 #include <ATen/ops/_native_batch_norm_legit_native.h>
 #include <ATen/ops/_native_batch_norm_legit_no_training.h>
 #include <ATen/ops/_native_batch_norm_legit_no_training_native.h>
-#include <ATen/ops/_batch_norm_with_update.h>
-#include <ATen/ops/_batch_norm_with_update_native.h>
-#include <ATen/ops/_batch_norm_no_update.h>
-#include <ATen/ops/_batch_norm_no_update_native.h>
-#include <ATen/ops/batch_norm_backward_native.h>
 #include <ATen/ops/alias.h>
 #include <ATen/ops/batch_norm.h>
 #include <ATen/ops/batch_norm_native.h>
@@ -484,58 +479,10 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu_template(
   return std::make_tuple(grad_input, grad_weight, grad_bias);
 }
-
-BatchNormBackend _select_batch_norm_backend(
-    const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean,
-    const Tensor& running_var, bool training, double eps) {
-
-  auto& ctx = at::globalContext();
-  bool cudnn_enabled = ctx.userEnabledCuDNN();
-
-  if (
-      input.is_cuda()
-      && input.scalar_type() != at::kBFloat16 && weight.scalar_type() != at::kBFloat16
-      && (input.scalar_type() != at::kHalf
-        || weight.scalar_type() == at::kFloat)
-      && weight.defined() && bias.defined()
-      && ((running_mean.defined() && running_var.defined())
-        || (!running_mean.defined() && !running_var.defined() && training))
-      && (input.dim() >= 3)
-      && ((input.sym_size(0) <= 880801 && training) // spatial, training
-        ||(input.sym_size(0) <= 65535 && !training)) //spatial, eval
-      && detail::getCUDAHooks().compiledWithCuDNN()
-      && eps >= detail::getCUDAHooks().batchnormMinEpsilonCuDNN()
-      && cudnn_enabled && detail::getCUDAHooks().versionCuDNN() >= 5110L
-      && input.sym_numel() < std::numeric_limits<std::int32_t>::max() // some cuDNN kernels have 32-bit indexing limitations
-      ) {
-    return BatchNormBackend::Cudnn;
-  }
-
-  if (
-      input.is_cuda()
-      && input.dim() <= MIOPEN_DIM_MAX
-      && input.scalar_type() != at::kDouble
-      && input.scalar_type() != at::kBFloat16
-      && (weight.scalar_type() != at::kHalf)
-      && weight.defined() && bias.defined()
-      && ((running_mean.defined() && running_var.defined())
-        || (!running_mean.defined() && !running_var.defined() && training))
-      && detail::getCUDAHooks().compiledWithMIOpen()
-      && cudnn_enabled
-      && input.suggest_memory_format() != MemoryFormat::ChannelsLast
-      && input.suggest_memory_format() != MemoryFormat::ChannelsLast3d
-      ) {
-    return BatchNormBackend::Miopen;
-  }
-
-  return BatchNormBackend::Native;
-}
-
 
 // _batch_norm_impl_index(_backward) are used in the JIT be able to keep the run-time selection
 // of backends, while enabling it to keep the information about the used backend, so that it can
 // use its corresponding backward implementation.
 // XXX: The indices of backends need to be kept synchronized between this function and its _backward.
-// TODO: remove cudnn_enabled arg
 std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
     const Tensor& input, const c10::optional<Tensor>& weight_opt /* optional */, const c10::optional<Tensor>& bias_opt /* optional */, const c10::optional<Tensor>& running_mean_opt /* optional */, const c10::optional<Tensor>& running_var_opt /* optional */,
     bool training, double momentum, double eps, bool cudnn_enabled) {
@@ -580,9 +527,24 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
     check_dims_match_num_input_features("bias", std::move(num_features), bias.sym_numel());
   }
 
-  BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, training, eps);
+  const bool use_cudnn = (
+      input.is_cuda()
+      && input.scalar_type() != at::kBFloat16 && weight.scalar_type() != at::kBFloat16
+      && (input.scalar_type() != at::kHalf
+        || weight.scalar_type() == at::kFloat)
+      && weight.defined() && bias.defined()
+      && ((running_mean.defined() && running_var.defined())
+        || (!running_mean.defined() && !running_var.defined() && training))
+      && (input.dim() >= 3)
+      && ((input.sym_size(0) <= 880801 && training) // spatial, training
+        ||(input.sym_size(0) <= 65535 && !training)) //spatial, eval
+      && detail::getCUDAHooks().compiledWithCuDNN()
+      && eps >= detail::getCUDAHooks().batchnormMinEpsilonCuDNN()
+      && cudnn_enabled && detail::getCUDAHooks().versionCuDNN() >= 5110L
+      && input.sym_numel() < std::numeric_limits<std::int32_t>::max() // some cuDNN kernels have 32-bit indexing limitations
+  );
 
-  if (backend == BatchNormBackend::Cudnn) {
+  if (use_cudnn) {
     auto input_c = input.contiguous(input.suggest_memory_format());
     auto weight_c = weight.contiguous();
     auto bias_c = bias.contiguous();
@@ -599,7 +561,19 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
 
   Tensor reserve = at::empty({0}, input.options().dtype(kByte));
 
-  if (backend == BatchNormBackend::Miopen) {
+  bool use_miopen = (input.is_cuda()
+                     && input.dim() <= MIOPEN_DIM_MAX
+                     && input.scalar_type() != at::kDouble
+                     && input.scalar_type() != at::kBFloat16
+                     && (weight.scalar_type() != at::kHalf)
+                     && weight.defined() && bias.defined()
+                     && ((running_mean.defined() && running_var.defined())
+                       || (!running_mean.defined() && !running_var.defined() && training))
+                     && detail::getCUDAHooks().compiledWithMIOpen()
+                     && cudnn_enabled
+                     );
+
+  if (use_miopen && input.suggest_memory_format() != MemoryFormat::ChannelsLast && input.suggest_memory_format() != MemoryFormat::ChannelsLast3d) {
     return std::tuple_cat(
         at::miopen_batch_norm(
             input.contiguous(), weight.contiguous(), bias.contiguous(),
@@ -663,7 +637,6 @@ std::tuple<Tensor, Tensor, Tensor> _batch_norm_impl_index_backward(
   TORCH_INTERNAL_ASSERT(false, "Unsupported impl_index in _batch_norm_impl_index_backward: ", impl_index);
 }
 
-// TODO: remove cudnn_enabled arg
 Tensor batch_norm(
     const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
     const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
@@ -674,30 +647,6 @@ Tensor batch_norm(
   const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
   return std::get<0>(at::_batch_norm_impl_index(input, weight, bias, running_mean, running_var,
                                                 training, momentum, eps, cudnn_enabled));
-  // TODO: switch to the new stack after the 2 week FC window
-  // if (training) {
-  //   BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, training, eps);
-  //   if (backend == BatchNormBackend::Cudnn || backend == BatchNormBackend::Miopen) {
-  //     auto input_c = input;
-  //     if (backend == BatchNormBackend::Cudnn) {
-  //       input_c = input.contiguous(input.suggest_memory_format());
-  //     } else {
-  //       input_c = input.contiguous();
-  //     }
-  //     auto weight_c = weight.contiguous();
-  //     auto bias_c = bias.contiguous();
-  //     auto rmean_c = running_mean.defined() ? running_mean.contiguous() : running_mean;
-  //     auto rvar_c = running_var.defined() ? running_var.contiguous() : running_var;
-  //     return std::get<0>(at::_batch_norm_with_update(input_c, weight_c, bias_c, const_cast<Tensor&>(rmean_c),
-  //         const_cast<Tensor&>(rvar_c), momentum, eps));
-  //   } else {
-  //     return std::get<0>(at::_batch_norm_with_update(input, weight, bias, const_cast<Tensor&>(running_mean),
-  //         const_cast<Tensor&>(running_var), momentum, eps));
-  //   }
-  // } else {
-  //   return std::get<0>(at::_batch_norm_no_update(input, weight, bias, running_mean, running_var,
-  //       momentum, eps));
-  // }
 }
 
 Tensor instance_norm(
@@ -849,38 +798,6 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_cpu(const Tensor& self, const c10:
   return batch_norm_cpu_out(self, weight_opt, bias_opt, running_mean_opt, running_var_opt, train, momentum, eps, output, save_mean, save_var);
 }
 
-std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_cpu(
-    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
-    Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
-  Tensor output, save_mean, save_var;
-  std::tie(output, save_mean, save_var) =
-      batch_norm_cpu(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps);
-  Tensor reserve = at::empty({0}, input.options().dtype(kByte));
-  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
-}
-
-std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> _batch_norm_with_update_cpu_out(
-    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
-    Tensor& running_mean, Tensor& running_var, double momentum, double eps,
-    Tensor& out, Tensor& save_mean, Tensor& save_var, Tensor& reserve) {
-  std::tie(out, save_mean, save_var) =
-      batch_norm_cpu_out(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps, out, save_mean, save_var);
-  return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>(out, save_mean, save_var, reserve);
-}
-
-
-std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_no_update(
-    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
-    const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
-    double momentum, double eps) {
-  const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
-  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
-  Tensor output, save_mean, save_var;
-  std::tie(output, save_mean, save_var) =
-      batch_norm_cpu(input, weight_opt, bias_opt, const_cast<Tensor&>(running_mean), const_cast<Tensor&>(running_var), /*update*/false, momentum, eps);
-  Tensor reserve = at::empty({0}, input.options().dtype(kByte));
-  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
-}
 
 std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_cpu(
     const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
@@ -909,13 +826,6 @@ std::tuple<Tensor&, Tensor&, Tensor&> _batch_norm_legit_no_stats_cpu_out(const T
   return batch_norm_cpu_out(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, eps, out, save_mean, save_var);
 }
 
-std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_cpu(
-    const Tensor& grad_output, const Tensor& input, const Tensor& weight,
-    const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
-    const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_var_opt,
-    bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
-  return batch_norm_backward_cpu(grad_output, input, weight, running_mean_opt, running_var_opt, save_mean_opt, save_var_opt, update, eps, grad_input_mask);
-}
 
 std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu(const Tensor& grad_out, const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt, const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_invstd_opt,
     bool train, double eps, std::array<bool,3> grad_input_mask) {
@@ -8,12 +8,4 @@ namespace at::native {
 using renorm_scale_factor_fn = void (*) (TensorIteratorBase& iter, double maxnorm);
 DECLARE_DISPATCH(renorm_scale_factor_fn, renorm_scale_factor_stub);
 
-enum class BatchNormBackend {
-  Native,
-  Cudnn,
-  Miopen,
-};
-
-TORCH_API BatchNormBackend _select_batch_norm_backend(const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean, const Tensor& running_var, bool training, double eps);
-
 } // namespace at::native
@@ -1,7 +1,5 @@
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include <ATen/cuda/detail/IndexUtils.cuh>
-#include <ATen/detail/CUDAHooksInterface.h>
-#include <ATen/native/Normalization.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/ReduceOps.h>
 #include <ATen/native/Resize.h>
@@ -14,8 +12,6 @@
 #include <ATen/Functions.h>
 #include <ATen/NativeFunctions.h>
 #else
-#include <ATen/ops/_batch_norm_with_update_native.h>
-#include <ATen/ops/batch_norm_backward_native.h>
 #include <ATen/ops/batch_norm_backward_elemt_native.h>
 #include <ATen/ops/batch_norm_backward_reduce_native.h>
 #include <ATen/ops/batch_norm_elemt_native.h>
@@ -23,12 +19,8 @@
 #include <ATen/ops/batch_norm_gather_stats_with_counts_native.h>
 #include <ATen/ops/batch_norm_stats_native.h>
 #include <ATen/ops/batch_norm_update_stats_native.h>
-#include <ATen/ops/cudnn_batch_norm.h>
-#include <ATen/ops/cudnn_batch_norm_backward.h>
 #include <ATen/ops/empty_like.h>
 #include <ATen/ops/from_blob.h>
-#include <ATen/ops/miopen_batch_norm.h>
-#include <ATen/ops/miopen_batch_norm_backward.h>
 #include <ATen/ops/native_batch_norm_backward_native.h>
 #include <ATen/ops/native_batch_norm_native.h>
 #include <ATen/ops/scalar_tensor.h>
@@ -481,54 +473,6 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_cuda(const Tensor& self, const c10
   return std::make_tuple(output, save_mean, save_invstd);
 }
 
-std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_cuda(
-    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
-    Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
-  // See [Note: hacky wrapper removal for optional tensor]
-  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
-  const Tensor& weight = *weight_maybe_owned;
-  const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
-  Tensor output, save_mean, save_var, reserve;
-
-  BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, /*training*/true, eps);
-  if (backend == BatchNormBackend::Cudnn) {
-    std::tie(output, save_mean, save_var, reserve) =
-        at::cudnn_batch_norm(input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
-  } else if (backend == BatchNormBackend::Miopen) {
-    reserve = at::empty({0}, input.options().dtype(kByte));
-    std::tie(output, save_mean, save_var) =
-        at::miopen_batch_norm(input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
-  } else {
-    reserve = at::empty({0}, input.options().dtype(kByte));
-    std::tie(output, save_mean, save_var) =
-        batch_norm_cuda(input, weight_opt, bias_opt, running_mean, running_var, /*training*/true, momentum, eps);
-  }
-  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
-}
-
-std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> _batch_norm_with_update_cuda_out(
-    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
-    Tensor& running_mean, Tensor& running_var, double momentum, double eps,
-    Tensor& out, Tensor& save_mean, Tensor& save_var, Tensor& reserve) {
-  // See [Note: hacky wrapper removal for optional tensor]
-  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
-  const Tensor& weight = *weight_maybe_owned;
-  const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
-
-  BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, /*training*/true, eps);
-  if (backend == BatchNormBackend::Cudnn) {
-    std::tie(out, save_mean, save_var, reserve) =
-        at::cudnn_batch_norm_out(out, save_mean, save_var, reserve, input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
-  } else if (backend == BatchNormBackend::Miopen) {
-    std::tie(out, save_mean, save_var) =
-        at::miopen_batch_norm_out(out, save_mean, save_var, input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps);
-  } else {
-    std::tie(out, save_mean, save_var) =
-        batch_norm_cuda_out(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps, out, save_mean, save_var);
-  }
-  return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>(out, save_mean, save_var, reserve);
-}
-
 std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_cuda(const Tensor& self, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt, Tensor& running_mean, Tensor& running_var, bool train, double momentum, double epsilon) {
   return batch_norm_cuda(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, epsilon);
 }
@@ -545,28 +489,6 @@ std::tuple<Tensor&, Tensor&, Tensor&> _batch_norm_legit_no_stats_cuda_out(const
   return batch_norm_cuda_out(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, epsilon, output, save_mean, save_invstd);
 }
 
-std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_cuda(
-    const Tensor& grad_output, const Tensor& input, const Tensor& weight,
-    const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
-    const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_var_opt,
-    bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
-  const Tensor& dummy_bias = at::empty(1);
-  const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
-  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
-  const Tensor& save_mean = c10::value_or_else(save_mean_opt, [] {return Tensor();});
-  const Tensor& save_var = c10::value_or_else(save_var_opt, [] {return Tensor();});
-
-  BatchNormBackend backend = _select_batch_norm_backend(input, weight, dummy_bias, running_mean, running_var, /*training*/true, eps);
-
-  if (backend == BatchNormBackend::Cudnn) {
-    return at::cudnn_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var, eps, reserve);
-  } else if (backend == BatchNormBackend::Miopen) {
-    return at::miopen_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var, eps);
-  } else {
-    return batch_norm_backward_cuda(grad_output, input, weight, running_mean, running_var, save_mean, save_var, update, eps, grad_input_mask);
-  }
-}
-
 std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda(const Tensor& grad_out, const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt, const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_invstd_opt, bool train, double epsilon, std::array<bool,3> grad_input_mask) {
   // See [Note: hacky wrapper removal for optional tensor]
   c10::MaybeOwned<Tensor> weight = at::borrow_from_optional_tensor(weight_opt);
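Two details of the backward wrapper deleted just above are easy to miss: `_select_batch_norm_backend` requires a defined bias, so the backward — which has none — passes `dummy_bias = at::empty(1)` purely to satisfy the selector, and the `reserve` tensor produced by the forward is consumed only on the cuDNN branch (MIOpen and the native kernel ignore it). A hypothetical caller-side view of the removed unified backward, with argument names following the `batch_norm_backward` schema in the YAML hunk below:

```cpp
// Illustrative only: autograd saved the reserve tensor from the forward and
// threaded it back into the unified backward, where only cuDNN used it.
std::tuple<at::Tensor, at::Tensor, at::Tensor> grads = at::batch_norm_backward(
    grad_out, input, weight, running_mean, running_var, save_mean, save_var,
    /*update*/true, eps, /*output_mask*/{true, true, true}, reserve);
```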
@@ -2,7 +2,6 @@
 #include <ATen/Config.h>
 #include <ATen/core/Tensor.h>
 #include <ATen/cuda/CUDAConfig.h>
-#include <ATen/native/cudnn/BatchNorm.h>
 
 #if !AT_CUDNN_ENABLED()
 
@@ -36,24 +35,18 @@ std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm_backward(
   AT_ERROR("cudnn_batch_norm_backward: ATen not compiled with cuDNN support");
 }
 
-size_t _get_cudnn_batch_norm_reserve_space_size(
-    const Tensor& input_t,
-    bool training) {
-  AT_ERROR(
-      "_get_cudnn_batch_norm_reserve_space_size: ATen not compiled with cuDNN support");
-}
-
 } // namespace native
 } // namespace at
 
 #else // AT_CUDNN_ENABLED
 
-#include <ATen/TensorUtils.h>
 #include <ATen/cuda/Exceptions.h>
 #include <ATen/cudnn/Descriptors.h>
 #include <ATen/cudnn/Types.h>
 #include <ATen/cudnn/Utils.h>
 
+#include <ATen/TensorUtils.h>
+
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
 #include <ATen/NativeFunctions.h>
@@ -98,21 +91,6 @@ cudnnBatchNormMode_t getCudnnBatchNormMode(
 
 } // namespace
 
-size_t _get_cudnn_batch_norm_reserve_space_size(
-    const Tensor& input_t,
-    bool training) {
-  size_t reserve_size;
-  TensorArg input{input_t, "input", 1};
-  TensorDescriptor idesc{*input, 4};
-  auto handle = getCudnnHandle();
-  cudnnBatchNormMode_t mode = getCudnnBatchNormMode(
-      training, input->suggest_memory_format(), input->dim());
-  auto op = CUDNN_BATCHNORM_OPS_BN;
-  AT_CUDNN_CHECK(cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
-      handle, mode, op, nullptr, idesc.desc(), &reserve_size));
-  return reserve_size;
-}
-
 std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
     const Tensor& input_t,
     const Tensor& weight_t,
@@ -201,8 +179,9 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
     Tensor workspace = at::empty(workspace_size, input->options().dtype(kByte));
 
     // get the reserved size and allocate as tensor
-    size_t reserve_size =
-        _get_cudnn_batch_norm_reserve_space_size(input_t, true /* training */);
+    size_t reserve_size;
+    AT_CUDNN_CHECK(cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
+        handle, mode, op, nullptr, idesc.desc(), &reserve_size));
     reserve = at::empty(reserve_size, input->options().dtype(kByte));
 
     AT_CUDNN_CHECK(cudnnBatchNormalizationForwardTrainingEx(
@@ -1,6 +0,0 @@
-namespace at::native {
-
-TORCH_API size_t
-_get_cudnn_batch_norm_reserve_space_size(const Tensor& input_t, bool training);
-
-} // namespace at::native
@@ -6,8 +6,6 @@
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/NativeFunctions.h>
 #else
-#include <ATen/ops/_batch_norm_with_update_native.h>
-#include <ATen/ops/batch_norm_backward_native.h>
 #include <ATen/ops/_native_batch_norm_legit_native.h>
 #include <ATen/ops/_to_dense_native.h>
 #include <ATen/ops/empty_native.h>
@@ -61,20 +59,6 @@ std::tuple<Tensor, Tensor, Tensor> _mkldnn_batch_norm_legit_no_stats(
   TORCH_CHECK(false, "_mkldnn_batch_norm_legit_no_stats: ATen not compiled with MKLDNN support");
 }
 
-std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_mkldnn(
-    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
-    Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
-  TORCH_CHECK(false, "_batch_norm_with_update_mkldnn: ATen not compiled with MKLDNN support");
-}
-
-std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_mkldnn(
-    const Tensor& grad_output, const Tensor& input, const Tensor& weight,
-    const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
-    const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_var_opt,
-    bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
-  TORCH_CHECK(false, "_new_batch_norm_backward_mkldnn: ATen not compiled with MKLDNN support");
-}
-
 } // namespace native
 } // namespace at
 
@@ -208,17 +192,6 @@ std::tuple<Tensor, Tensor, Tensor> mkldnn_batch_norm(
 }
 
-
-std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_mkldnn(
-    const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt,
-    Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
-  Tensor output, save_mean, save_var;
-  std::tie(output, save_mean, save_var) =
-      mkldnn_batch_norm(input, weight_opt, bias_opt, running_mean, running_var, /*train*/true, momentum, eps);
-  Tensor reserve = empty_mkldnn({0}, input.scalar_type());
-  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
-}
-
 
 std::tuple<Tensor, Tensor, Tensor> _mkldnn_batch_norm_legit(
     const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& bias_opt, Tensor& running_mean, Tensor& running_var,
     bool train,
@@ -237,15 +210,6 @@ std::tuple<Tensor, Tensor, Tensor> _mkldnn_batch_norm_legit_no_stats(
 }
 
-
-std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_mkldnn(
-    const Tensor& grad_output, const Tensor& input, const Tensor& weight,
-    const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt,
-    const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_var_opt,
-    bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
-  return mkldnn_batch_norm_backward(grad_output, input, weight, running_mean_opt, running_var_opt, save_mean_opt, save_var_opt, update, eps, grad_input_mask);
-}
-
 
 std::tuple<Tensor, Tensor, Tensor> mkldnn_batch_norm_backward(const Tensor& grad_output,
     const Tensor& input, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& running_mean_opt, const c10::optional<Tensor>& running_var_opt, const c10::optional<Tensor>& save_mean_opt, const c10::optional<Tensor>& save_invstd_opt,
     bool train,
@@ -10,9 +10,7 @@
 #include <ATen/Functions.h>
 #include <ATen/NativeFunctions.h>
 #else
-#include <ATen/ops/_batch_norm_with_update_native.h>
 #include <ATen/ops/_native_batch_norm_legit_native.h>
-#include <ATen/ops/batch_norm_backward_native.h>
 #include <ATen/ops/native_batch_norm.h>
 #include <ATen/ops/native_batch_norm_backward_native.h>
 #include <ATen/ops/native_batch_norm_native.h>
@@ -408,36 +406,6 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_mps(const Tensor& self,
   return std::make_tuple(output, save_mean, save_var);
 }
 
-std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_mps(const Tensor& input,
-                                                                       const c10::optional<Tensor>& weight_opt,
-                                                                       const c10::optional<Tensor>& bias_opt,
-                                                                       Tensor& running_mean,
-                                                                       Tensor& running_var,
-                                                                       double momentum,
-                                                                       double eps) {
-  Tensor output, save_mean, save_var;
-  std::tie(output, save_mean, save_var) =
-      batch_norm_mps(input, weight_opt, bias_opt, running_mean, running_var, /*train*/ true, momentum, eps);
-  Tensor reserve = at::empty({0}, input.options().dtype(kByte));
-  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
-}
-
-std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> _batch_norm_with_update_mps_out(const Tensor& input,
-                                                                               const c10::optional<Tensor>& weight_opt,
-                                                                               const c10::optional<Tensor>& bias_opt,
-                                                                               Tensor& running_mean,
-                                                                               Tensor& running_var,
-                                                                               double momentum,
-                                                                               double eps,
-                                                                               Tensor& out,
-                                                                               Tensor& save_mean,
-                                                                               Tensor& save_var,
-                                                                               Tensor& reserve) {
-  std::tie(out, save_mean, save_var) = batch_norm_mps_out(
-      input, weight_opt, bias_opt, running_mean, running_var, /*update*/ true, momentum, eps, out, save_mean, save_var);
-  return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>(out, save_mean, save_var, reserve);
-}
-
 std::tuple<Tensor, Tensor, Tensor> _batch_norm_legit_mps(const Tensor& self,
                                                          const c10::optional<Tensor>& weight_opt,
                                                          const c10::optional<Tensor>& bias_opt,
@@ -503,29 +471,6 @@ static string get_mem_string(c10::MemoryFormat memory_format) {
 }
 
 // Batch norm backward
-std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_mps(const Tensor& grad_output,
-                                                                const Tensor& input,
-                                                                const Tensor& weight,
-                                                                const c10::optional<Tensor>& running_mean_opt,
-                                                                const c10::optional<Tensor>& running_var_opt,
-                                                                const c10::optional<Tensor>& save_mean_opt,
-                                                                const c10::optional<Tensor>& save_var_opt,
-                                                                bool update,
-                                                                double eps,
-                                                                std::array<bool, 3> grad_input_mask,
-                                                                const Tensor& reserve) {
-  return batch_norm_backward_mps(grad_output,
-                                 input,
-                                 weight,
-                                 running_mean_opt,
-                                 running_var_opt,
-                                 save_mean_opt,
-                                 save_var_opt,
-                                 update,
-                                 eps,
-                                 grad_input_mask);
-}
-
 std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_mps(const Tensor& grad_out,
                                                            const Tensor& input,
                                                            const c10::optional<Tensor>& weight_opt,
@@ -6496,32 +6496,6 @@
     SparseCPU, SparseCUDA: norm_sparse
   autogen: native_norm.ScalarOpt_dim_dtype_out
 
-- func: _batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)
-  dispatch:
-    CPU: _batch_norm_with_update_cpu
-    CUDA: _batch_norm_with_update_cuda
-    MPS: _batch_norm_with_update_mps
-    MkldnnCPU: _batch_norm_with_update_mkldnn
-  autogen: _batch_norm_with_update_functional
-
-- func: _batch_norm_with_update.out(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps, *, Tensor(d!) out, Tensor(e!) save_mean, Tensor(f!) save_invstd, Tensor(g!) reserve) -> (Tensor(d!), Tensor(e!), Tensor(f!), Tensor(g!))
-  dispatch:
-    CPU: _batch_norm_with_update_cpu_out
-    CUDA: _batch_norm_with_update_cuda_out
-    MPS: _batch_norm_with_update_mps_out
-
-- func: _batch_norm_no_update(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)
-  dispatch:
-    CompositeExplicitAutograd: _batch_norm_no_update
-  autogen: _batch_norm_no_update.out
-
-- func: batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, bool update, float eps, bool[3] output_mask, Tensor reserve) -> (Tensor, Tensor, Tensor)
-  dispatch:
-    CPU: _new_batch_norm_backward_cpu
-    CUDA: _new_batch_norm_backward_cuda
-    MPS: _new_batch_norm_backward_mps
-    MkldnnCPU: _new_batch_norm_backward_mkldnn
-
 # TODO: reduce signatures down to one when optional args is available
 - func: _sparse_sum(Tensor self) -> Tensor
 
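A gloss on the removed schema entries: `Tensor(a!)`/`Tensor(b!)` mark `running_mean`/`running_var` as mutated in place — which is why the per-backend kernels deleted earlier bind them as `Tensor&` — `dispatch:` maps each dispatch key to one of those kernels, and `autogen:` asks the code generator for the extra functional/out variants named. The commented-out block in `Tensor batch_norm` near the top of this diff shows how the frontend was meant to route into these ops once the two-week forward-compatibility (FC) window passed; condensed:

```cpp
// Lifted from the commented-out dispatch in Tensor batch_norm above
// (illustrative; the const_casts are as in the original, since the schema
// mutates the running stats in place).
if (training) {
  return std::get<0>(at::_batch_norm_with_update(
      input, weight, bias, const_cast<Tensor&>(running_mean),
      const_cast<Tensor&>(running_var), momentum, eps));
} else {
  return std::get<0>(at::_batch_norm_no_update(
      input, weight, bias, running_mean, running_var, momentum, eps));
}
```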
@@ -114,7 +114,6 @@ dtensor_fails = {
     xfail("as_strided", "partial_views"),
     xfail("as_strided_scatter"),
     xfail("bernoulli"),
-    xfail("_batch_norm_with_update"),
     xfail("block_diag"),
     xfail("broadcast_shapes"),
     xfail("cauchy"),
@@ -6,9 +6,6 @@ aten::_adaptive_avg_pool2d
 aten::_adaptive_avg_pool2d.out
 aten::_addmm_activation
 aten::_addmm_activation.out
-aten::_batch_norm_no_update
-aten::_batch_norm_with_update
-aten::_batch_norm_with_update_functional
 aten::_euclidean_dist.out
 aten::_fused_dropout
 aten::_fused_dropout.out
@@ -79,7 +76,6 @@ aten::atanh
 aten::atanh.out
 aten::atanh_
 aten::baddbmm_
-aten::batch_norm_backward
 aten::bitwise_and.Scalar
 aten::bitwise_and.Scalar_Tensor
 aten::bitwise_and.Scalar_Tensor_out
@@ -30,8 +30,6 @@ aten::_amp_update_scale.out
 aten::_amp_update_scale_
 aten::_assert_async
 aten::_assert_async.msg
-aten::_batch_norm_no_update.out
-aten::_batch_norm_with_update.out
 aten::_cdist_backward
 aten::_cdist_backward.out
 aten::_cdist_forward
@@ -11,7 +11,7 @@ import unittest
 
 from torch.testing._internal.common_utils import unMarkDynamoStrictTest
 from torch.testing._internal.common_utils import TestCase, run_tests, is_iterable_of_tensors, IS_MACOS, \
-    IS_X86, parametrize, TEST_WITH_ASAN, TEST_WITH_ROCM, noncontiguous_like
+    IS_X86, parametrize, TEST_WITH_ASAN, noncontiguous_like
 from torch.testing._internal.common_utils import skipIfRocm, runOnRocm
 import torch
 from torch import Tensor
@@ -368,10 +368,6 @@ aliasing_ops_list_return = {
     # 'tensor_split' not composite compliant, see vjp_fail
 }
 
-skip_noncontig = {
-    '_batch_norm_with_update',
-}
-
 
 @unittest.skipIf(TEST_WITH_ASAN, "tests time out with asan, are probably redundant")
 @unMarkDynamoStrictTest
|
|||||||
xfail('nn.functional.scaled_dot_product_attention'),
|
xfail('nn.functional.scaled_dot_product_attention'),
|
||||||
xfail("torch.ops.aten._flash_attention_forward"),
|
xfail("torch.ops.aten._flash_attention_forward"),
|
||||||
xfail("torch.ops.aten._efficient_attention_forward"),
|
xfail("torch.ops.aten._efficient_attention_forward"),
|
||||||
|
|
||||||
# RuntimeError: Expected contiguous tensor, but got
|
|
||||||
# non-contiguous tensor for argument #2 'grad_output'
|
|
||||||
decorate(
|
|
||||||
'_batch_norm_with_update',
|
|
||||||
decorator=expectedFailureIf(TEST_WITH_ROCM),
|
|
||||||
device_type='cuda',
|
|
||||||
)
|
|
||||||
}))
|
}))
|
||||||
@opsToleranceOverride('TestOperators', 'test_grad', (
|
@opsToleranceOverride('TestOperators', 'test_grad', (
|
||||||
tol1('nn.functional.binary_cross_entropy_with_logits',
|
tol1('nn.functional.binary_cross_entropy_with_logits',
|
||||||
@@ -445,10 +433,9 @@ class TestOperators(TestCase):
         args = [sample.input] + list(sample.args)
         kwargs = sample.kwargs
 
-        if op.name not in skip_noncontig:
-            noncontig_sample = sample.noncontiguous()
-            noncontig_args = [noncontig_sample.input] + list(noncontig_sample.args)
-            noncontig_kwargs = noncontig_sample.kwargs
+        noncontig_sample = sample.noncontiguous()
+        noncontig_args = [noncontig_sample.input] + list(noncontig_sample.args)
+        noncontig_kwargs = noncontig_sample.kwargs
 
         diff_argnums = tuple(i for i, arg in enumerate(args) if diff_arg(arg))
         assert len(diff_argnums) > 0
@@ -471,12 +458,11 @@ class TestOperators(TestCase):
             return result
 
         result = grad(wrapped_fn, diff_argnums)(*args, **kwargs)
+        result_noncontig = grad(wrapped_fn, diff_argnums)(*noncontig_args, **noncontig_kwargs)
         expected = _autograd_grad(_as_tuple(wrapped_fn(*args, **kwargs)), diff_args)
-        self.assertEqual(result, expected)
-
-        if op.name not in skip_noncontig:
-            result_noncontig = grad(wrapped_fn, diff_argnums)(*noncontig_args, **noncontig_kwargs)
-            self.assertEqual(result_noncontig, expected)
+        self.assertEqual(result, expected)
+        self.assertEqual(result_noncontig, expected)
 
     @with_tf32_off  # https://github.com/pytorch/pytorch/issues/86798
     @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
@@ -490,8 +476,7 @@ class TestOperators(TestCase):
         skip('nn.functional.max_unpool2d'),  # fails everywhere except on windows
         skip('nn.functional.max_unpool3d'),  # fails everywhere except on mac
         xfail("native_batch_norm"),  # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents
         xfail("_native_batch_norm_legit"),  # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents
-        xfail("_batch_norm_with_update"),  # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents
 
         xfail('nn.functional.scaled_dot_product_attention'),
         xfail('torch.ops.aten._flash_attention_forward'),
@@ -560,17 +545,15 @@ class TestOperators(TestCase):
                 self.jvp_opinfo_test(outplace_variant, sample,
                                      sample.output_process_fn_grad,
                                      clone_inputs=False,
-                                     fixme_ref_jvp_local=fixme_ref_jvp_local,
-                                     test_noncontig=op.name not in skip_noncontig)
+                                     fixme_ref_jvp_local=fixme_ref_jvp_local)
                 if is_valid_inplace_sample_input(sample, op, inplace_variant):
                     self.jvp_opinfo_test(inplace_variant, sample,
                                          sample.output_process_fn_grad,
                                          clone_inputs=True,
-                                         fixme_ref_jvp_local=fixme_ref_jvp_local,
-                                         test_noncontig=op.name not in skip_noncontig)
+                                         fixme_ref_jvp_local=fixme_ref_jvp_local)
 
     def jvp_opinfo_test(self, fn, sample, output_process_fn,
-                        clone_inputs, fixme_ref_jvp_local, test_noncontig):
+                        clone_inputs, fixme_ref_jvp_local):
         # NB: we used requires_grad=True to determine where the primals are,
         # but don't need that information otherwise
         args = (sample.input,) + sample.args
@@ -580,6 +563,15 @@ class TestOperators(TestCase):
         orig_primals = tree_map(lambda x: x.detach(), primals)
         orig_tangents = tree_map(lambda x: torch.randn_like(x), primals)
 
+        noncontig_sample = sample.noncontiguous()
+        noncontig_args = (noncontig_sample.input,) + noncontig_sample.args
+        noncontig_kwargs = sample.kwargs
+        noncontig_fn, primals = normalize_op_input_output2(
+            fn, noncontig_args, noncontig_kwargs,
+            output_process_fn, requires_grad=True)
+        noncontig_primals = tree_map(lambda x: x.detach(), primals)
+        noncontig_tangents = tree_map(lambda x: noncontiguous_like(x), orig_tangents)
+
         def maybe_clone_inputs():
             if clone_inputs:
                 primals = tree_map(torch.clone, orig_primals)
@@ -594,24 +586,15 @@ class TestOperators(TestCase):
         primals, tangents = maybe_clone_inputs()
         primal_outs, tangent_outs = jvp(contig_fn, primals, tangents)
 
+        noncontig_primal_outs, noncontig_tangent_outs = jvp(noncontig_fn,
+                                                            noncontig_primals,
+                                                            noncontig_tangents)
+
         self.assertEqual(primal_outs, expected_primal_outs)
         self.assertEqual(tangent_outs, expected_tangent_outs)
 
-        if test_noncontig:
-            noncontig_sample = sample.noncontiguous()
-            noncontig_args = (noncontig_sample.input,) + noncontig_sample.args
-            noncontig_kwargs = sample.kwargs
-            noncontig_fn, primals = normalize_op_input_output2(
-                fn, noncontig_args, noncontig_kwargs,
-                output_process_fn, requires_grad=True)
-            noncontig_primals = tree_map(lambda x: x.detach(), primals)
-            noncontig_tangents = tree_map(lambda x: noncontiguous_like(x), orig_tangents)
-            noncontig_primal_outs, noncontig_tangent_outs = jvp(noncontig_fn,
-                                                                noncontig_primals,
-                                                                noncontig_tangents)
-
-            self.assertEqual(noncontig_primal_outs, expected_primal_outs)
-            self.assertEqual(noncontig_tangent_outs, expected_tangent_outs)
+        self.assertEqual(noncontig_primal_outs, expected_primal_outs)
+        self.assertEqual(noncontig_tangent_outs, expected_tangent_outs)
 
     @with_tf32_off  # https://github.com/pytorch/pytorch/issues/86798
     @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
@@ -672,22 +655,22 @@ class TestOperators(TestCase):
             result = fn(*primals)
             cotangents = tree_map(lambda x: torch.randn_like(x), result)
 
+            noncontig_fn, noncontig_primals = normalize_op_input_output(_op, sample.noncontiguous())
+            noncontig_cotangents = tree_map(lambda x: noncontiguous_like(x), cotangents)
+
             out, vjp_fn = vjp(fn, *primals)
             self.assertEqual(out, result)
             result_vjps = vjp_fn(cotangents)
 
+            out_noncontig, vjp_fn = vjp(noncontig_fn, *noncontig_primals)
+            self.assertEqual(out_noncontig, result)
+            noncontig_result_vjps = vjp_fn(noncontig_cotangents)
+
             _, vjp_fn = ref_vjp(fn, *primals)
             expected_vjps = vjp_fn(cotangents)
 
             self.assertEqual(result_vjps, expected_vjps)
-
-            if op.name not in skip_noncontig:
-                noncontig_fn, noncontig_primals = normalize_op_input_output(_op, sample.noncontiguous())
-                noncontig_cotangents = tree_map(lambda x: noncontiguous_like(x), cotangents)
-                out_noncontig, vjp_fn = vjp(noncontig_fn, *noncontig_primals)
-                self.assertEqual(out_noncontig, result)
-                noncontig_result_vjps = vjp_fn(noncontig_cotangents)
-                self.assertEqual(noncontig_result_vjps, expected_vjps)
+            self.assertEqual(noncontig_result_vjps, expected_vjps)
 
         _test(op)
         for a_op in op.aliases:
@@ -847,8 +830,6 @@ class TestOperators(TestCase):
         xfail("to_sparse"),
         xfail("native_batch_norm"),
        xfail("_native_batch_norm_legit"),
-        # TODO: implement batching rule
-        xfail("_batch_norm_with_update"),
     }))
     @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
     @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
@ -940,8 +921,6 @@ class TestOperators(TestCase):
|
|||||||
skip('linalg.svdvals'), # # really annoying thing where it passes correctness check but not has_batch_rule
|
skip('linalg.svdvals'), # # really annoying thing where it passes correctness check but not has_batch_rule
|
||||||
skip("native_batch_norm"),
|
skip("native_batch_norm"),
|
||||||
skip("_native_batch_norm_legit"),
|
skip("_native_batch_norm_legit"),
|
||||||
# TODO: implement batching rule
|
|
||||||
skip("_batch_norm_with_update"),
|
|
||||||
xfail('__getitem__', ''), # dynamic error
|
xfail('__getitem__', ''), # dynamic error
|
||||||
xfail('nanquantile', device_type='cpu'), # checks q via a .item() call
|
xfail('nanquantile', device_type='cpu'), # checks q via a .item() call
|
||||||
xfail('nn.functional.gaussian_nll_loss'), # checks var for if any value < 0
|
xfail('nn.functional.gaussian_nll_loss'), # checks var for if any value < 0
|
||||||
@ -1066,8 +1045,6 @@ class TestOperators(TestCase):
|
|||||||
xfail('nn.functional.batch_norm', 'without_cudnn'),
|
xfail('nn.functional.batch_norm', 'without_cudnn'),
|
||||||
xfail("native_batch_norm"),
|
xfail("native_batch_norm"),
|
||||||
xfail("_native_batch_norm_legit"),
|
xfail("_native_batch_norm_legit"),
|
||||||
# TODO: implement batching rule
|
|
||||||
xfail("_batch_norm_with_update"),
|
|
||||||
|
|
||||||
# https://github.com/pytorch/pytorch/issues/96560
|
# https://github.com/pytorch/pytorch/issues/96560
|
||||||
# ROCm: NotImplementedError
|
# ROCm: NotImplementedError
|
||||||
@ -1253,8 +1230,6 @@ class TestOperators(TestCase):
|
|||||||
xfail('sparse.mm', 'reduce'),
|
xfail('sparse.mm', 'reduce'),
|
||||||
xfail("native_batch_norm"),
|
xfail("native_batch_norm"),
|
||||||
xfail("_native_batch_norm_legit"),
|
xfail("_native_batch_norm_legit"),
|
||||||
# TODO: implement batching rule
|
|
||||||
xfail("_batch_norm_with_update"),
|
|
||||||
xfail("native_dropout_backward"),
|
xfail("native_dropout_backward"),
|
||||||
xfail("index_fill"), # aten::_unique hit the vmap fallback which is currently disabled
|
xfail("index_fill"), # aten::_unique hit the vmap fallback which is currently disabled
|
||||||
}))
|
}))
|
||||||
@ -1331,8 +1306,6 @@ class TestOperators(TestCase):
|
|||||||
xfail('sparse.mm', 'reduce'),
|
xfail('sparse.mm', 'reduce'),
|
||||||
xfail("native_batch_norm"),
|
xfail("native_batch_norm"),
|
||||||
xfail("_native_batch_norm_legit"),
|
xfail("_native_batch_norm_legit"),
|
||||||
# TODO: implement batching rule
|
|
||||||
xfail("_batch_norm_with_update"),
|
|
||||||
xfail('as_strided', 'partial_views'),
|
xfail('as_strided', 'partial_views'),
|
||||||
}))
|
}))
|
||||||
def test_vjpvmap(self, device, dtype, op):
|
def test_vjpvmap(self, device, dtype, op):
|
||||||
@ -1591,8 +1564,6 @@ class TestOperators(TestCase):
|
|||||||
# place, were not batched.
|
# place, were not batched.
|
||||||
xfail("native_batch_norm"),
|
xfail("native_batch_norm"),
|
||||||
xfail("_native_batch_norm_legit"),
|
xfail("_native_batch_norm_legit"),
|
||||||
# TODO: implement batching rule
|
|
||||||
xfail("_batch_norm_with_update"),
|
|
||||||
xfail('native_dropout_backward'),
|
xfail('native_dropout_backward'),
|
||||||
}))
|
}))
|
||||||
@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
|
@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
|
||||||
@ -1837,14 +1808,6 @@ class TestOperators(TestCase):
|
|||||||
skip('sparse.sampled_addmm', ''),
|
skip('sparse.sampled_addmm', ''),
|
||||||
skip('sparse.mm', 'reduce'),
|
skip('sparse.mm', 'reduce'),
|
||||||
skip('native_layer_norm', '', device_type='cpu'),
|
skip('native_layer_norm', '', device_type='cpu'),
|
||||||
|
|
||||||
# RuntimeError: Expected contiguous tensor, but got
|
|
||||||
# non-contiguous tensor for argument #2 'grad_output'
|
|
||||||
decorate(
|
|
||||||
'_batch_norm_with_update',
|
|
||||||
decorator=expectedFailureIf(TEST_WITH_ROCM),
|
|
||||||
device_type='cuda',
|
|
||||||
)
|
|
||||||
})
|
})
|
||||||
@opsToleranceOverride('TestOperators', 'test_vmap_autograd_grad', (
|
@opsToleranceOverride('TestOperators', 'test_vmap_autograd_grad', (
|
||||||
tol1('linalg.householder_product',
|
tol1('linalg.householder_product',
|
||||||
@@ -3625,8 +3625,6 @@ class TestVmapOperatorsOpInfo(TestCase):
        # which will be updated in place, were not batched.
        xfail('native_batch_norm'),
        xfail('_native_batch_norm_legit'),
-        # TODO: implement batching rule
-        xfail('_batch_norm_with_update'),
        xfail('tril'),  # Exception not raised on error input
        xfail('triu'),  # Exception not raised on error input
        xfail('as_strided', 'partial_views'),
@@ -3666,8 +3664,6 @@ class TestVmapOperatorsOpInfo(TestCase):
        # which will be updated in place, were not batched.
        xfail('native_batch_norm'),
        xfail('_native_batch_norm_legit'),
-        # TODO: implement batching rule
-        xfail('_batch_norm_with_update'),
        xfail('histogram'),
        xfail('scatter_reduce', 'sum'),
        xfail('scatter_reduce', 'mean'),
@@ -192,7 +192,6 @@ inductor_skips["cuda"] = {
    "nn.functional.cosine_embedding_loss": {b8},
    "native_batch_norm": {f16, f32, f64},
    "_native_batch_norm_legit": {f16, f32, f64},
-    "_batch_norm_with_update": {f16, f32, f64},
}

if not SM80OrLater:
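A note on notation: the `b8`/`f16`/`f32`/`f64` values in these skip tables are dtype shorthands that the inductor opinfo test file defines for itself; roughly (an assumption about the aliases, shown only for readability):

```python
import torch

# assumed dtype aliases used by the skip tables above
b8, f16, f32, f64 = torch.bool, torch.float16, torch.float32, torch.float64
```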
@@ -157,11 +157,6 @@ EXPECTED_SKIPS_OR_FAILS: Tuple[onnx_test_common.DecorateMeta, ...] = (
        dtypes=(torch.float16,),
        reason="fixme: Assertion error: result mismatch and type error",
    ),
-    skip(
-        "_batch_norm_with_update",
-        dtypes=(torch.float16,),
-        reason="fixme: Assertion error: result mismatch and type error",
-    ),
    xfail(
        "_softmax_backward_data",
        reason=onnx_test_common.reason_dynamo_does_not_support("assert all(isinstance(a, KNOWN_TYPES) for a in flat_args)")
@@ -1357,20 +1352,6 @@ SKIP_XFAIL_SUBTESTS: tuple[onnx_test_common.DecorateMeta, ...] = (
        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
        reason="https://github.com/pytorch/pytorch/issues/115106",
    ),
-    skip(
-        "_batch_norm_with_update",
-        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
-        reason="https://github.com/pytorch/pytorch/issues/115106",
-    ),
-    # TODO: This test currently fails only for certain inputs, e.g. shape([3, 1]).
-    # Numerically the ONNX program is correct, but the output shapes for `save_mean`
-    # and `save_var` were tensor(-2.1268) instead of the correct tensor([-2.1268])
-    # for example.
-    skip(
-        "_batch_norm_with_update",
-        model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
-        reason="not supported yet",
-    ),
    xfail(
        "addmm",  # xfail can't only use dtypes to catch all cases
        matcher=lambda sample: sample.input.dtype
@@ -22,7 +22,7 @@ torch._C._get_graph_executor_optimize(True)

from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH_EXECUTOR, \
    enable_profiling_mode_for_profiling_tests, slowTest, skipIfTorchDynamo, TEST_WITH_ASAN, \
-    TEST_WITH_ROCM, IS_FBCODE
+    IS_FBCODE
from torch.testing._internal.jit_utils import JitTestCase, \
    RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, warmup_backward, set_fusion_group_inlining, \
    clone_inputs, get_traced_sample_variant_pairs, TensorExprTestOptions, NoTracerWarnContextManager
@@ -2202,7 +2202,6 @@ class TestTEFuser(JitTestCase):

    @skipIfTorchDynamo("too slow")
    @unittest.skipIf(TEST_WITH_ASAN, "takes 10+ minutes on asan")
-    @unittest.skipIf(TEST_WITH_ROCM, "Tensor-likes are not close for nans")
    def test_batch_norm(self):
        def test(fn, args):
            trace = torch.jit.trace(fn, args)
@@ -708,11 +708,8 @@ meta_function_device_expected_failures_only_outplace = defaultdict(dict)
meta_function_device_skips = defaultdict(dict)

meta_function_device_expected_failures['cpu'] = {
-    # TODO: The decomps for these batch norm ops return different dtypes depending
-    # on the device. We should make this work better with meta tensors.
    torch.native_batch_norm: {bf16, f16},
    torch._native_batch_norm_legit: {bf16, f16},
-    torch.ops.aten._batch_norm_with_update: {bf16, f16},
    torch.native_layer_norm: {bf16, f16},
}

@@ -727,11 +724,8 @@ meta_function_device_expected_failures['cuda'] = {
}

meta_function_device_skips['cpu'] = {
-    # TODO: The decomps for these batch norm ops return different dtypes depending
-    # on the device. We should make this work better with meta tensors.
    torch.native_batch_norm: {f32, f64},
    torch._native_batch_norm_legit: {f32, f64},
-    torch.ops.aten._batch_norm_with_update: {f32, f64},
}

meta_function_device_skips['cuda'] = {
@@ -856,13 +850,9 @@ meta_dispatch_device_expected_failures = defaultdict(dict)
meta_dispatch_device_skips = defaultdict(dict)

meta_dispatch_device_expected_failures['cpu'] = {
-    # TODO: The decomps for these batch norm ops return different dtypes depending
-    # on the device. We should make this work better with meta tensors.
    aten.native_batch_norm.default: {bf16, f16},
    aten._native_batch_norm_legit.default: {bf16, f16},
    aten._native_batch_norm_legit.no_stats: {bf16, f16},
-    aten._batch_norm_with_update.default: {bf16, f16},
-
    aten.native_layer_norm.default: {bf16, f16},
    aten.histc.default: {f16},
    aten.histc.out: {f16},
@@ -887,13 +877,9 @@ meta_dispatch_device_expected_failures['cuda'] = {

meta_dispatch_device_skips['cpu'] = {
    aten._embedding_bag_forward_only.default: {bf16, f16, f32, f64},

-    # TODO: The decomps for these batch norm ops return different dtypes depending
-    # on the device. We should make this work better with meta tensors.
    aten.native_batch_norm.default: {f32, f64},
    aten._native_batch_norm_legit.default: {f32, f64},
    aten._native_batch_norm_legit.no_stats: {f32, f64},
-    aten._batch_norm_with_update.default: {f32, f64},
-
    # If the computation dtype is different from the input
    # dtype this will fail. CPU execution may also have a
@@ -11395,7 +11395,6 @@ class TestConsistency(TestCaseMPS):
            'nn.functional.gelu',
            'nn.functional.glu',
            '_native_batch_norm_legit',
-            '_batch_norm_with_update',
            'native_batch_norm',
            'softmax',
            '_softmax_backward_data',
@@ -1927,15 +1927,6 @@ inplace_symbolic_tensor_failures = {
}

out_symbolic_tensor_failures = {
-    # Cast error details: Unable to cast (...) to Tensor
-    #
-    # This happens because the test is set up to call the out variant using the `out` kwarg:
-    #   torch._some_op(arg1, arg2, out=(out1, out2, out3))
-    #
-    # However, this only works on torch ops, not aten ops. For `_batch_norm_with_update`,
-    # this fails because the op has no python bindings, so it doesn't support the `out` kwarg
-    # way of calling its out variant.
-    xfail('_batch_norm_with_update', ''),
    xfail('_native_batch_norm_legit', ''),
    xfail('angle', ''),
    xfail('argmax', ''),
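The removed comment above is worth illustrating. Python-level torch ops expose their out variants through the `out=` keyword, including the tuple form for multi-output ops; a minimal sketch using standard torch API (nothing here is specific to this diff):

```python
import torch

x = torch.randn(4)
out = torch.empty(4)
torch.sin(x, out=out)  # single-output op: out= takes one tensor

# multi-output ops take a tuple, the out=(out1, out2, ...) form the
# removed comment refers to:
values = torch.empty(2)
indices = torch.empty(2, dtype=torch.long)
torch.topk(x, 2, out=(values, indices))
```

`_batch_norm_with_update` has no such Python binding (it is listed in `_SKIP_PYTHON_BINDINGS` further down in this diff), so the test harness cannot reach its out variant this way.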
@@ -1250,20 +1250,6 @@
  self: grad.neg()
  result: auto_element_wise

-- name: _batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)
-  input, weight, bias: "grad.defined() ? batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, /*update*/true, eps, grad_input_mask, retain_variables ? result3.clone() : result3) : std::tuple<Tensor, Tensor, Tensor>()"
-  result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, true, eps)
-
-- name: _batch_norm_no_update(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)
-  input, weight, bias: "grad.defined() ? batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, /*update*/false, eps, grad_input_mask, retain_variables ? result3.clone() : result3) : std::tuple<Tensor, Tensor, Tensor>()"
-  result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, false, eps)
-
-- name: batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, bool update, float eps, bool[3] output_mask, Tensor reserve) -> (Tensor, Tensor, Tensor)
-  input, weight, grad_out: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, running_mean, running_var, update, eps, save_mean, save_var, grad_input_mask)
-  save_mean: not_implemented("batch_norm_backward save_mean")
-  save_var: not_implemented("batch_norm_backward save_var")
-  reserve: not_implemented("batch_norm_backward reserve")
-
- name: nextafter(Tensor self, Tensor other) -> Tensor
  self: not_implemented("nextafter")
  other: not_implemented("nextafter")
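For orientation: in `derivatives.yaml`, `result0`..`result3` name the forward outputs (here: output, save_mean, save_var, reserve) and `grad_input_mask` selects which input gradients to compute. Taking the argument order straight from the removed `batch_norm_backward` schema above, a call would look like this hypothetical sketch (the op exists only on builds that still carry the consolidation):

```python
import torch

x = torch.randn(2, 3, 4)
grad_out = torch.randn_like(x)
weight = torch.ones(3)
running_mean, running_var = torch.zeros(3), torch.ones(3)
save_mean, save_var = torch.zeros(3), torch.ones(3)
reserve = torch.empty(0, dtype=torch.uint8)  # empty outside the cudnn path

grad_input, grad_weight, grad_bias = torch.ops.aten.batch_norm_backward(
    grad_out, x, weight, running_mean, running_var,
    save_mean, save_var, True, 1e-5, [True, True, True], reserve)
```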
@@ -160,12 +160,9 @@ _SKIP_PYTHON_BINDINGS = [
    "fill.Tensor",  # only used by the functionalization pass
    "fill.Scalar",  # only used by the functionalization pass
    "lift.*",
-    "normal_functional",  # only used by the functionalization pass
+    "normal_functional",  # only used by the functionalization pas
    "nbytes",
    "itemsize",
-    "_batch_norm_with_update",
-    "_batch_norm_with_update_out",
-    "_batch_norm_no_update",
]

SKIP_PYTHON_BINDINGS = [
@@ -1132,7 +1132,6 @@ def _meta_in_tls_dispatch_include() -> _bool: ...
def _stash_obj_in_tls(key: str, arg: Any) -> None: ...
def _get_obj_in_tls(key: str) -> Any: ...
def _is_key_in_tls(key: str) -> _bool: ...
-def _select_batch_norm_backend(*args, **kwargs) -> BatchNormBackend: ...
def _select_conv_backend(*args, **kwargs) -> ConvBackend: ...
def _conv_determine_backend_memory_format(
    input: Tensor,
@@ -1198,8 +1197,6 @@ class _LinalgBackend:
    Cusolver: _LinalgBackend
    Magma: _LinalgBackend

-class BatchNormBackend(Enum): ...
-
class ConvBackend(Enum): ...

class Tag(Enum):
@@ -1933,114 +1933,6 @@ def _native_batch_norm_legit_functional(
    return output, save_mean, save_rstd, new_running_mean, new_running_var


-def _get_batch_norm_reserve_tensor(
-    input: Tensor,
-    weight: Optional[Tensor],
-    bias: Optional[Tensor],
-    running_mean: Tensor,
-    running_var: Tensor,
-    eps: float,
-    training: bool,
-) -> Tensor:
-    """
-    Return a reserve tensor for batch norm, used only by cudnn to pass forward state to the
-    backward pass. This is needed for `_batch_norm_with_update` and `_batch_norm_no_update`,
-    which support a variety of backends including cudnn. We create this tensor here to get
-    the correct shape in the traced graph if we detect that will call the cudnn kernel,
-    and rely on DCE to avoid materializing this tensor.
-    """
-    backend = torch._C._select_batch_norm_backend(  # type: ignore[attr-defined]
-        input, weight, bias, running_mean, running_var, True, eps
-    )
-    reserve_size = 0
-    if backend == torch._C._BatchNormBackend.Cudnn:  # type: ignore[attr-defined]
-        reserve_size = torch._C._get_cudnn_batch_norm_reserve_space_size(input, training)  # type: ignore[attr-defined]
-    return torch.empty(
-        reserve_size, dtype=torch.uint8, layout=input.layout, device=input.device
-    )
-
-
-@register_decomposition(aten._batch_norm_with_update.default)
-def _batch_norm_with_update(
-    input: Tensor,
-    weight: Optional[Tensor],
-    bias: Optional[Tensor],
-    running_mean: Tensor,
-    running_var: Tensor,
-    momentum: float,
-    eps: float,
-) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
-    output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
-        input,
-        weight,
-        bias,
-        running_mean,
-        running_var,
-        True,  # training
-        momentum,
-        eps,
-        False,  # functional
-    )
-    reserve = _get_batch_norm_reserve_tensor(
-        input, weight, bias, running_mean, running_var, eps, training=True
-    )
-    return output, save_mean, save_rstd, reserve
-
-
-@register_decomposition(aten._batch_norm_with_update_functional.default)
-def _batch_norm_with_update_functional(
-    input: Tensor,
-    weight: Optional[Tensor],
-    bias: Optional[Tensor],
-    running_mean: Tensor,
-    running_var: Tensor,
-    momentum: float,
-    eps: float,
-) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
-    (
-        output,
-        save_mean,
-        save_rstd,
-        new_rm,
-        new_rv,
-    ) = native_batch_norm_helper(
-        input, weight, bias, running_mean, running_var, True, momentum, eps, True
-    )
-    reserve = _get_batch_norm_reserve_tensor(
-        input, weight, bias, running_mean, running_var, eps, training=True
-    )
-    assert new_rm is not None, "new_running_mean should not be None"
-    assert new_rv is not None, "new_running_var should not be None"
-    return (output, save_mean, save_rstd, reserve, new_rm, new_rv)
-
-
-@register_decomposition(aten._batch_norm_no_update.default)
-def _batch_norm_no_update(
-    input: Tensor,
-    weight: Optional[Tensor],
-    bias: Optional[Tensor],
-    running_mean: Tensor,
-    running_var: Tensor,
-    momentum: float,
-    eps: float,
-) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
-    output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
-        input,
-        weight,
-        bias,
-        running_mean,
-        running_var,
-        False,  # training
-        momentum,
-        eps,
-        False,  # functional
-    )
-    reserve = _get_batch_norm_reserve_tensor(
-        input, weight, bias, running_mean, running_var, eps, training=False
-    )
-    return output, save_mean, save_rstd, reserve
-
-
@register_decomposition(aten._fused_dropout)
@out_wrapper("out0", "out1")
@pw_cast_for_opmath
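To make the removed decomposition concrete: `_batch_norm_with_update` is the training-mode entry point that updates the running stats in place and returns a cudnn reserve tensor as its fourth output. A hypothetical call, valid only on builds that still include the op (shapes follow the usual (N, C, ...) batch norm convention):

```python
import torch

x = torch.randn(2, 3, 4)
weight, bias = torch.ones(3), torch.zeros(3)
running_mean, running_var = torch.zeros(3), torch.ones(3)

out, save_mean, save_rstd, reserve = torch.ops.aten._batch_norm_with_update(
    x, weight, bias, running_mean, running_var, 0.1, 1e-5)
# running_mean/running_var have been updated in place (momentum 0.1);
# `reserve` is empty unless the cudnn backend was selected.
```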
@@ -2145,34 +2037,6 @@ def _broadcast_batch_norm_backward(x, broadcast_mask):
    return x


-@register_decomposition(aten.batch_norm_backward.default)
-def batch_norm_backward(
-    grad_out: Tensor,
-    input: Tensor,
-    weight: Optional[Tensor],
-    running_mean: Optional[Tensor],
-    running_var: Optional[Tensor],
-    save_mean: Optional[Tensor],
-    save_invstd: Optional[Tensor],
-    train: bool,
-    eps: float,
-    output_mask: List[bool],
-    reserve: Tensor,
-) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
-    return native_batch_norm_backward(
-        grad_out,
-        input,
-        weight,
-        running_mean,
-        running_var,
-        save_mean,
-        save_invstd,
-        train,
-        eps,
-        output_mask,
-    )
-
-
@register_decomposition(aten.native_batch_norm_backward.default)
def native_batch_norm_backward(
    grad_out: Tensor,
@@ -291,34 +291,6 @@ def native_batch_norm_backward(
    return (grad_input, grad_weight, grad_bias)


-@register_decomposition_for_jvp(aten.batch_norm_backward)
-def batch_norm_backward(
-    grad_out: Tensor,
-    input: Tensor,
-    weight: Tensor,
-    running_mean: Optional[Tensor],
-    running_var: Optional[Tensor],
-    save_mean: Optional[Tensor],
-    save_var: Optional[Tensor],
-    update: bool,
-    eps: float,
-    output_mask: List[bool],
-    reserve: Tensor,
-) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
-    return native_batch_norm_backward(
-        grad_out,
-        input,
-        weight,
-        running_mean,
-        running_var,
-        save_mean,
-        save_var,
-        update,
-        eps,
-        output_mask,
-    )
-
-
_register_jit_decomposition_for_jvp(torch.ops.aten.trace.default, use_python=True)
_register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss2d_backward.default)
@@ -328,4 +300,3 @@ _register_jit_decomposition_for_jvp(torch.ops.aten.log_sigmoid_forward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.native_layer_norm_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.native_batch_norm_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.cudnn_batch_norm_backward.default)
-_register_jit_decomposition_for_jvp(torch.ops.aten.batch_norm_backward.default)
@@ -1005,7 +1005,6 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
        "torch._C._scatter_out",
        "torch._C._scatter",
        "torch._C._select_conv_backend",
-        "torch._C._select_batch_norm_backend",
        "torch._C._set_autograd_fallback_mode",
        "torch._C._set_backcompat_broadcast_warn",
        "torch._C._set_backcompat_keepdim_warn",
@@ -753,7 +753,7 @@ def min_cut_rematerialization_partition(
    recomputable_ops = set(recomputable_ops) if recomputable_ops is not None else set(default_recomputable_ops)

    random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like]
-    compute_intensive_ops = [aten.mm, aten.convolution, aten.convolution_backward, aten.bmm, aten.addmm, aten.upsample_bilinear2d, aten._softmax, aten._softmax_backward_data, aten.native_layer_norm, aten.native_layer_norm_backward, aten.native_batch_norm, aten.native_batch_norm_backward, aten._native_batch_norm_legit, aten._batch_norm_with_update, aten.batch_norm_backward]  # noqa: E501,B950
+    compute_intensive_ops = [aten.mm, aten.convolution, aten.convolution_backward, aten.bmm, aten.addmm, aten.upsample_bilinear2d, aten._softmax, aten._softmax_backward_data, aten.native_layer_norm, aten.native_layer_norm_backward, aten.native_batch_norm, aten.native_batch_norm_backward, aten._native_batch_norm_legit]  # noqa: E501,B950

    fusible_ops = recomputable_ops | set(random_ops)
    if AOT_PARTITIONER_DEBUG:
@@ -54,10 +54,6 @@ inductor_decompositions = get_decompositions(
        aten._native_batch_norm_legit,
        aten._native_batch_norm_legit_functional,
        aten._native_batch_norm_legit_no_training,
-        aten._batch_norm_with_update,
-        aten._batch_norm_with_update_functional,
-        aten._batch_norm_no_update,
-        aten.batch_norm_backward,
        aten.native_batch_norm,
        aten.native_group_norm,
        aten.native_layer_norm,
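For context, `get_decompositions` looks up the registered Python decompositions for the given overloads and returns a table that Inductor hands to AOTAutograd; the revert simply shrinks the requested set. A minimal sketch with ops that exist either way:

```python
import torch
from torch._decomp import get_decompositions

aten = torch.ops.aten
# returns a {op overload: python impl} mapping for every registered entry
decomps = get_decompositions([aten.native_batch_norm, aten.native_layer_norm])
```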
@@ -18,7 +18,6 @@
#include <ATen/dlpack.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/ForeachUtils.h>
-#include <ATen/native/Normalization.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/util/AbortHandler.h>
#include <c10/util/Backtrace.h>
@@ -92,10 +91,7 @@
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <torch/csrc/profiler/combined_traceback.h>
#include <sstream>

#ifdef USE_CUDA
-#include <ATen/cuda/CUDAConfig.h>
-#include <ATen/native/cudnn/BatchNorm.h>
#include <ATen/native/transformers/cuda/sdp_utils.h>
#endif

@@ -2126,44 +2122,6 @@ Call this whenever a new thread is created in order to propagate values from
      },
      "Checks if a tensor's data pointer is COW");

-  py_module.def(
-      "_get_cudnn_batch_norm_reserve_space_size",
-      [](const at::Tensor& input, bool training) {
-#ifdef USE_CUDA
-        return at::native::_get_cudnn_batch_norm_reserve_space_size(
-            input, training);
-#else
-        TORCH_CHECK(false, "PyTorch was not built with cuda");
-#endif
-      },
-      py::arg("input"),
-      py::arg("training"));
-
-  py::enum_<at::native::BatchNormBackend>(py_module, "_BatchNormBackend")
-      .value("Native", at::native::BatchNormBackend::Native)
-      .value("Cudnn", at::native::BatchNormBackend::Cudnn)
-      .value("Miopen", at::native::BatchNormBackend::Miopen);
-
-  py_module.def(
-      "_select_batch_norm_backend",
-      [](const at::Tensor& input,
-         const at::Tensor& weight,
-         const at::Tensor& bias,
-         const at::Tensor& running_mean,
-         const at::Tensor& running_var,
-         bool training,
-         double eps) {
-        return at::native::_select_batch_norm_backend(
-            input, weight, bias, running_mean, running_var, training, eps);
-      },
-      py::arg("input"),
-      py::arg("weight"),
-      py::arg("bias"),
-      py::arg("running_mean"),
-      py::arg("running_var"),
-      py::arg("training"),
-      py::arg("eps"));
-
  const auto& defaultGenerator = at::detail::getDefaultCPUGenerator();
  THPDefaultCPUGenerator =
      (THPGenerator*)THPGenerator_initDefaultGenerator(defaultGenerator);
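The two bindings removed above were consumed from Python by the `_get_batch_norm_reserve_tensor` decomposition shown earlier. A hypothetical round trip, meaningful only on a CUDA build that still carries these bindings:

```python
import torch

x = torch.randn(2, 3, 4, device="cuda")
weight = bias = torch.ones(3, device="cuda")
running_mean = torch.zeros(3, device="cuda")
running_var = torch.ones(3, device="cuda")

# binding names and argument order taken from the removed pybind code above
backend = torch._C._select_batch_norm_backend(
    x, weight, bias, running_mean, running_var, True, 1e-5)
if backend == torch._C._BatchNormBackend.Cudnn:
    nbytes = torch._C._get_cudnn_batch_norm_reserve_space_size(x, training=True)
```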
@@ -3312,7 +3312,6 @@ const OperatorMap<std::string>& GetShapeFunctionMappings() {
      {"aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "native_batch_norm"},
      {"aten::_native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "native_batch_norm"},
      {"aten::_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "native_batch_norm"},
-      {"aten::_batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)", "native_batch_norm"},
      {"aten::cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor", "cross_entropy_loss"},
      {"aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor", "broadcast_three"},
      {"aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor", "broadcast_one_three"},

@@ -1430,11 +1430,6 @@ add_shape_compute_mapping(
    "aten::_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
    native_batch_norm,
)
-add_shape_compute_mapping(
-    "_batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)",
-    native_batch_norm,
-)
-
add_shape_compute_mapping(
    "aten::cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor",
    cross_entropy_loss,
@@ -509,19 +509,6 @@ def sample_inputs__native_batch_norm_legit(op_info, device, dtype, requires_grad
    else:
        yield SampleInput(sample.input, args=(args[2], args[3], training, momentum, eps))

-def sample_inputs__batch_norm_with_update(op_info, device, dtype, requires_grad, **kwargs):
-    samples = sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs)
-    for sample in samples:
-        # torch.native_batch_norm does not support 0 numel tensors
-        # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
-        if sample.input.numel() == 0:
-            continue
-        args = sample.args
-        momentum = sample.kwargs.get('momentum', 0.5)
-        eps = sample.kwargs.get('eps', 1e-5)
-        if any(args[i] is None for i in range(4)):
-            continue
-        yield SampleInput(sample.input, args=(args[2], args[3], args[0], args[1], momentum, eps))

def sample_inputs_nn_activation_relu(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
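A note on the removed generator: `sample_inputs_batch_norm` yields samples in `torch.nn.functional.batch_norm` argument order, so the generator permutes them into the aten schema's order; that is all the `args[2], args[3], args[0], args[1]` shuffle does. Schematically (argument orders taken from the schemas in this diff and from `F.batch_norm`):

```python
# F.batch_norm sample:  args = (running_mean, running_var, weight, bias)
# aten op expects:      (input, weight, bias, running_mean, running_var,
#                        momentum, eps)
# hence the yield uses: args=(args[2], args[3], args[0], args[1],
#                             momentum, eps)
```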
@@ -13111,42 +13098,6 @@ op_db: List[OpInfo] = [
                            "TestCompositeCompliance", "test_forward_ad"),
           )
           ),
-    OpInfo('_batch_norm_with_update',
-           op=torch.ops.aten._batch_norm_with_update,
-           aten_name='_batch_norm_with_update',
-           dtypes=floating_types_and(torch.float16, torch.bfloat16),
-           supports_forward_ad=True,
-           supports_fwgrad_bwgrad=True,
-           assert_jit_shape_analysis=True,
-           # TODO: Avoid COW materialize
-           supports_cow_input_no_materialize=False,
-           sample_inputs_func=sample_inputs__batch_norm_with_update,
-           skips=(
-               # NotImplementedError: Could not run
-               # 'aten::native_batch_norm.out' with arguments from the 'CPU' backend.
-               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type="cpu"),
-               # RuntimeError: out_invstd.dim() == 1 && out_invstd.is_contiguous() && out_invstd.sizes()[0]
-               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type="cuda"),
-               # Problem with _get_numerical_jacobian
-               # IndexError: tuple index out of range
-               DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
-               # RuntimeError: deepEquals(input.iValue, deepCopiedInput) INTERNAL ASSERT FAILED
-               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
-               # https://github.com/pytorch/pytorch/issues/85960
-               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'),
-               DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-5, rtol=5e-5)}),
-                            "TestCompositeCompliance", "test_forward_ad"),
-               # _batch_norm_with_update expects contiguous inputs for cudnn and miopen
-               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type="cuda"),
-               DecorateInfo(unittest.expectedFailure,
-                            'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides', device_type="cuda"),
-               # _batch_norm_with_update does not have python bindings
-               DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
-               # aten out variants do not accept out= kwarg, only python out variants
-               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
-               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
-           )
-           ),
    OpInfo('nn.functional.cosine_similarity',
           aten_name="cosine_similarity",
           dtypes=floating_types_and(torch.half, torch.bfloat16),