cuDNN dispatching heuristics rely on version checks, but currently only the compile-time version is exposed. To let users resolve https://github.com/pytorch/pytorch/issues/166643 on their end by updating their local cuDNN installation, we need to check the runtime version rather than the compile-time version. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167111 Approved by: https://github.com/Skylion007
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
|
#include <ATen/core/Tensor.h>
|
|
#include <ATen/Config.h>
|
|
#include <ATen/Parallel.h>
|
|
#include <ATen/TensorOperators.h>
|
|
#include <ATen/native/CanUse32BitIndexMath.h>
|
|
#include <ATen/native/ConvolutionMM3d.h>
|
|
#include <ATen/native/ConvUtils.h>
|
|
#include <ATen/native/Pool.h>
|
|
#include <ATen/native/cpu/DepthwiseConvKernel.h>
|
|
#include <ATen/native/utils/ParamUtils.h>
|
|
#include <ATen/native/xnnpack/Engine.h>
|
|
#include <c10/core/GradMode.h>
|
|
#include <c10/util/accumulate.h>
|
|
#include <c10/util/irange.h>
|
|
#include <c10/macros/Macros.h>
|
|
#include <algorithm>
|
|
#include <limits>
|
|
#include <utility>
|
|
|
|
#ifndef AT_PER_OPERATOR_HEADERS
|
|
#include <ATen/Functions.h>
|
|
#else
|
|
#include <ATen/ops/permute.h>
|
|
#endif
|
|
|
|
#if AT_NNPACK_ENABLED()
|
|
#include <nnpack.h>
|
|
#endif
|
|
|
|
#if AT_MKLDNN_ENABLED()
|
|
#include <ATen/native/mkldnn/Utils.h>
|
|
#endif
|
|
|
|
#ifndef AT_PER_OPERATOR_HEADERS
|
|
#include <ATen/Functions.h>
|
|
#include <ATen/NativeFunctions.h>
|
|
#else
|
|
#include <ATen/ops/_conv_depthwise2d.h>
|
|
#include <ATen/ops/_convolution.h>
|
|
#include <ATen/ops/_convolution_double_backward_native.h>
|
|
#include <ATen/ops/_convolution_mode.h>
|
|
#include <ATen/ops/_convolution_mode_native.h>
|
|
#include <ATen/ops/_convolution_native.h>
|
|
#include <ATen/ops/_mps_convolution.h>
|
|
#include <ATen/ops/_mps_convolution_transpose.h>
|
|
#include <ATen/ops/_nnpack_available.h>
|
|
#include <ATen/ops/_nnpack_spatial_convolution.h>
|
|
#include <ATen/ops/_slow_conv2d_backward.h>
|
|
#include <ATen/ops/_unsafe_view.h>
|
|
#include <ATen/ops/cat.h>
|
|
#include <ATen/ops/constant_pad_nd.h>
|
|
#include <ATen/ops/conv1d_native.h>
|
|
#include <ATen/ops/conv2d_native.h>
|
|
#include <ATen/ops/conv3d_native.h>
|
|
#include <ATen/ops/conv_depthwise3d.h>
|
|
#include <ATen/ops/conv_transpose1d_native.h>
|
|
#include <ATen/ops/conv_transpose2d_native.h>
|
|
#include <ATen/ops/conv_transpose3d_native.h>
|
|
#include <ATen/ops/convolution.h>
|
|
#include <ATen/ops/convolution_backward_native.h>
|
|
#include <ATen/ops/convolution_backward_overrideable.h>
|
|
#include <ATen/ops/convolution_backward_overrideable_native.h>
|
|
#include <ATen/ops/convolution_native.h>
|
|
#include <ATen/ops/convolution_overrideable.h>
|
|
#include <ATen/ops/convolution_overrideable_native.h>
|
|
#include <ATen/ops/cudnn_convolution.h>
|
|
#include <ATen/ops/cudnn_convolution_transpose.h>
|
|
#include <ATen/ops/empty.h>
|
|
#include <ATen/ops/empty_like.h>
|
|
#include <ATen/ops/empty_native.h>
|
|
#include <ATen/ops/miopen_convolution.h>
|
|
#include <ATen/ops/miopen_convolution_transpose.h>
|
|
#include <ATen/ops/miopen_depthwise_convolution.h>
|
|
#include <ATen/ops/mkldnn_convolution.h>
|
|
#include <ATen/ops/mps_convolution_backward.h>
|
|
#include <ATen/ops/mps_convolution_transpose_backward.h>
|
|
#include <ATen/ops/slow_conv3d.h>
|
|
#include <ATen/ops/slow_conv_dilated2d.h>
|
|
#include <ATen/ops/slow_conv_dilated3d.h>
|
|
#include <ATen/ops/slow_conv_transpose2d.h>
|
|
#include <ATen/ops/slow_conv_transpose3d.h>
|
|
#include <ATen/ops/thnn_conv2d.h>
|
|
#include <ATen/ops/view_as_real.h>
|
|
#include <ATen/ops/zeros.h>
|
|
#include <ATen/ops/zeros_like.h>
|
|
#endif
|
|
|
|
constexpr int MIOPEN_DIM_MAX = 5;
|
|
|
|
namespace at::native {
|
|
|
|
|
|
static bool conv_benchmark_empty_cache = true;
|
|
|
|
// Check workload to activate fast depthwise FP16 cudnn conv kernels
|
|
template <typename T>
|
|
static bool check_cudnn_depthwise_workload(const at::Tensor& input, T stride) {
|
|
auto w = at::symint::size<T>(input, 3); // same as h
|
|
auto ch = at::symint::size<T>(input, 1);
|
|
auto bs = at::symint::size<T>(input, 0);
|
|
if (stride==1) {
|
|
if (w >= 7) {
|
|
// All batch sizes and nb_channels
|
|
if (w >= 112) {
|
|
return true;
|
|
}
|
|
|
|
// large nb_channels
|
|
if (ch >= 1024) {
|
|
// NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
|
|
if (w >= 56) {
|
|
return true;
|
|
} else if (bs >= 32) {
|
|
return true;
|
|
}
|
|
}
|
|
|
|
// batch_size specific
|
|
if (bs >= 128) {
|
|
// NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
|
|
if (ch >= 512) {
|
|
return true;
|
|
} else if (ch >= 64) {
|
|
if (w >= 14) {
|
|
return true;
|
|
}
|
|
} else if ((ch >= 32) && (w >=28)) {
|
|
return true;
|
|
}
|
|
} else if (bs >= 64) {
|
|
// NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
|
|
if ((ch >= 256) && (w >= 14)) {
|
|
return true;
|
|
} else if ((ch >= 32) && (w >= 28)) {
|
|
return true;
|
|
}
|
|
} else if (bs >= 32) {
|
|
// NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
|
|
if ((ch >= 256) && (w >= 14)) {
|
|
return true;
|
|
} else if ((ch >= 128) && (w >= 28)) {
|
|
return true;
|
|
} else if ((ch >= 32) && (w >= 56)) {
|
|
return true;
|
|
}
|
|
} else if (bs >= 16) {
|
|
if ((ch >= 1024) && (w >= 14)) {
|
|
return true;
|
|
}
|
|
// NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
|
|
if ((ch >= 256) && (w >= 28)) {
|
|
return true;
|
|
} else if ((ch >= 32) && (w >= 56)) {
|
|
return true;
|
|
}
|
|
} else if (bs >= 8) {
|
|
// NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
|
|
if ((ch >= 512) && (w >= 28)) {
|
|
return true;
|
|
} else if ((ch >= 64) && (w >= 56)) {
|
|
return true;
|
|
}
|
|
}
|
|
}
|
|
} else if (stride==2) {
|
|
if (ch < 256) {
|
|
return false;
|
|
}
|
|
|
|
if (w >= 7) {
|
|
if (bs >= 128) {
|
|
// NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
|
|
if (ch >= 1024) {
|
|
return true;
|
|
} else if ((ch >= 512) && (w >= 14)) {
|
|
return true;
|
|
} else if (w >= 28) {
|
|
return true;
|
|
}
|
|
} else if (bs >= 64) {
|
|
// NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
|
|
if ((ch >= 512) && (w >= 14)) {
|
|
return true;
|
|
} else if (w >= 28) {
|
|
return true;
|
|
}
|
|
} else if (bs >= 32) {
|
|
// NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
|
|
if ((ch >= 1024) && (w >= 14)) {
|
|
return true;
|
|
} else if (w >= 28) {
|
|
return true;
|
|
}
|
|
} else if (bs >= 16) {
|
|
// NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
|
|
if ((ch >= 512) && (w >= 28)) {
|
|
return true;
|
|
} else if (w >= 56) {
|
|
return true;
|
|
}
|
|
} else if (bs >= 8) {
|
|
// NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
|
|
if ((ch >= 1024) && (w >= 28)) {
|
|
return true;
|
|
} else if (w >= 56) {
|
|
return true;
|
|
}
|
|
} else if (bs >= 1) {
|
|
if ((ch >= 512) && (w >=112)) {
|
|
return true;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return false;
|
|
}
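// The branches above are easier to read as (stride, batch size, channels, width)
// thresholds. Illustrative example only (shapes assumed, not exhaustive): for
// stride 1, an NCHW input of shape [64, 256, 14, 14] satisfies the
// (bs >= 64, ch >= 256, w >= 14) branch and returns true, while
// [4, 64, 14, 14] falls through every branch and returns false, so the native
// depthwise kernel is used instead.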
|
|
|
|
// simplified version for cudnn 8.2 and above
|
|
template <typename T>
|
|
static bool check_cudnn_depthwise_workload_with_filter(const at::Tensor& input, T stride, const at::Tensor& weight) {
|
|
// 1D conv
|
|
if(at::symint::size<T>(input, 2) == 1 && stride == 1){
|
|
return true;
|
|
}
|
|
|
|
// 2d conv
|
|
// only square filters
|
|
if (at::symint::size<T>(weight, 2) != at::symint::size<T>(weight, 3)) return false;
|
|
auto filter = at::symint::size<T>(weight, 3);
|
|
// only 1/3/5 filter
|
|
if (filter != 1 && filter != 3 && filter != 5) return false;
|
|
// we don't enforce square input but only check width to reduce heuristic space
|
|
if (at::symint::size<T>(input, 3) < 7) return false; // min width 7
|
|
auto w = at::symint::size<T>(input, 3);
|
|
// only 1/2 stride, use cudnn for all stride 1
|
|
if (stride == 1) return true;
|
|
if (stride != 2) return false;
|
|
|
|
auto ch = at::symint::size<T>(input, 1);
|
|
auto bs = at::symint::size<T>(input, 0);
|
|
// special case: batch size 1 shows good performance in many cases
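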
|
|
if (bs == 1) {
|
|
if (filter == 1 && w <= 28) return true;
|
|
if (filter == 3 || filter == 5) return true;
|
|
} else {
|
|
if (filter == 1 && bs <= 16 && ch >= 128 && w <= 7) return true;
|
|
if (filter == 3 || filter == 5) {
|
|
if ((ch >= 512) || (ch >= 256 && w >= 28)) return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
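// Illustrative example only (shapes assumed): an [8, 256, 56, 56] input with a
// 3x3 depthwise filter and stride 2 passes the square-filter, 1/3/5
// filter-size, width >= 7 and (ch >= 256 && w >= 28) checks above, so the
// cuDNN kernel is preferred; the same input with a 7x7 filter fails the 1/3/5
// filter-size check and falls back to the native kernel.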
|
|
|
|
|
|
#if defined(C10_MOBILE)
|
|
static bool xnnpack_use_convolution2d(
|
|
const Tensor& input,
|
|
const Tensor& weight,
|
|
const at::OptionalIntArrayRef bias_sizes_opt,
|
|
const IntArrayRef padding,
|
|
const IntArrayRef stride,
|
|
const IntArrayRef dilation,
|
|
const int64_t groups,
|
|
const bool transposed) {
|
|
return xnnpack::use_convolution2d(input, weight, bias_sizes_opt, padding, stride, dilation, groups, transposed);
|
|
}
|
|
|
|
static bool xnnpack_use_convolution2d(
|
|
const Tensor& input,
|
|
const Tensor& weight,
|
|
const at::OptionalSymIntArrayRef bias_sizes_opt,
|
|
const SymIntArrayRef padding,
|
|
const SymIntArrayRef stride,
|
|
const SymIntArrayRef dilation,
|
|
const c10::SymInt groups,
|
|
const bool transposed) {
|
|
// Never use xnnpack for symbolic tracing
|
|
return false;
|
|
}
|
|
#endif
|
|
|
|
// This struct is templated so that we can run backend selection in a dynamic
|
|
// shapes context; all of the real kernel selection in eager mode runs with
|
|
// int64_t
|
|
template <typename T>
|
|
struct ConvParams {
|
|
std::vector<T> stride;
|
|
std::vector<T> padding;
|
|
std::vector<T> dilation;
|
|
bool transposed{};
|
|
std::vector<T> output_padding;
|
|
T groups{};
|
|
bool benchmark{};
|
|
bool deterministic{};
|
|
bool cudnn_enabled{};
|
|
bool allow_tf32{};
|
|
|
|
bool is_strided() const {
|
|
return std::any_of(
|
|
stride.cbegin(), stride.cend(), [](const T& s) { return s != 1; });
|
|
}
|
|
|
|
bool is_dilated() const {
|
|
return std::any_of(
|
|
dilation.cbegin(), dilation.cend(), [](const T& d) { return d != 1; });
|
|
}
|
|
|
|
bool is_padded() const {
|
|
return std::any_of(
|
|
padding.cbegin(), padding.cend(), [](const T& p) { return p != 0; });
|
|
}
|
|
|
|
bool is_output_padding_neg() const {
|
|
return std::any_of(
|
|
output_padding.cbegin(),
|
|
output_padding.cend(),
|
|
[](const T& p) { return p < 0; });
|
|
}
|
|
|
|
bool is_output_padding_big() const {
|
|
// Revisit this with std::views::zip once C++23 is available.
|
|
for (auto i: c10::irange(output_padding.size())) {
|
|
if (output_padding[i] >= stride[i]) {
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool is_padding_neg() const {
|
|
return std::any_of(
|
|
padding.cbegin(), padding.cend(), [](const T& p) { return p < 0; });
|
|
}
|
|
|
|
bool is_dilation_neg() const {
|
|
return std::any_of(
|
|
dilation.cbegin(), dilation.cend(), [](const T& d) { return d < 0; });
|
|
}
|
|
|
|
bool is_stride_nonpos() const {
|
|
return std::any_of(
|
|
stride.cbegin(), stride.cend(), [](const T& s) { return s <= 0; });
|
|
}
|
|
|
|
void view1d_as_2d() {
|
|
if (stride.size() == 1) {
|
|
stride.insert(stride.begin(), 1);
|
|
padding.insert(padding.begin(), 0);
|
|
dilation.insert(dilation.begin(), 1);
|
|
output_padding.insert(output_padding.begin(), 0);
|
|
}
|
|
}
|
|
|
|
bool use_cpu_depthwise3x3_winograd(const at::Tensor& input, const at::Tensor& weight, const std::optional<at::Tensor>& bias) const {
|
|
#if defined(__ARM_NEON__) || (defined(__riscv_v_intrinsic) && __riscv_v_intrinsic>=12000)
|
|
// Currently only 3x3 depthwise convolutions on tensors of float are supported.
|
|
return (input.ndimension() == 4) &&
|
|
(at::symint::size<T>(input, 1) == groups) &&
|
|
(weight.ndimension() == 4 ) &&
|
|
(at::symint::size<T>(weight, 0) % at::symint::size<T>(input, 1) == 0) &&
|
|
(at::symint::size<T>(weight, 1) == 1) &&
|
|
(at::symint::size<T>(weight, 2) == 3) &&
|
|
(at::symint::size<T>(weight, 3) == 3) &&
|
|
(input.device().is_cpu()) &&
|
|
(input.scalar_type() == at::kFloat) &&
|
|
input.is_contiguous() &&
|
|
(weight.device().is_cpu()) &&
|
|
(weight.scalar_type() == at::kFloat) &&
|
|
weight.is_contiguous() &&
|
|
(!bias.has_value() || bias->is_contiguous()) &&
|
|
!is_strided() &&
|
|
!is_dilated() &&
|
|
!transposed;
|
|
#else
|
|
return false;
|
|
#endif
|
|
}
|
|
|
|
bool needs_64bit_indexing_no_split(const at::Tensor& input, const at::Tensor& weight) const {
|
|
constexpr int64_t int_max = std::numeric_limits<int>::max();
|
|
auto numel_input = at::symint::numel<T>(input);
|
|
// empty input
|
|
if (numel_input == 0) {
|
|
return false;
|
|
}
|
|
// input size can not be reduced to the range of int by splitting the batch dim
|
|
auto n = at::symint::size<T>(input, 0);
|
|
if (numel_input / n > int_max) {
|
|
return true;
|
|
}
|
|
// output size can not be reduced to the range of int by splitting the batch dim
|
|
T outsize = 1;
|
|
if (transposed) {
|
|
auto o = conv_input_size(at::symint::sizes<T>(input), at::symint::sizes<T>(weight), padding, output_padding, stride, dilation, groups);
|
|
outsize = c10::multiply_integers(o.begin() + 1, o.end());
|
|
} else {
|
|
auto o = conv_output_size(at::symint::sizes<T>(input), at::symint::sizes<T>(weight), padding, stride, dilation);
|
|
outsize = c10::multiply_integers(o.begin() + 1, o.end());
|
|
}
|
|
return outsize > int_max;
|
|
}
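// Illustrative sketch of the check above (shapes assumed): with
// int_max == 2^31 - 1, an input of shape [2, 1024, 1024, 1024] has
// numel_input / n == 1024^3 ~= 1.07e9 elements per batch entry, which fits in
// int32, so as long as the per-batch output also fits, the convolution can be
// handled by splitting the batch dim; a single sample of shape
// [1, 4096, 1024, 1024] (~4.3e9 elements) cannot be split further and returns
// true, forcing the 64-bit indexing path.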
|
|
|
|
bool use_cudnn(const at::Tensor& input, const at::Tensor& weight) const {
|
|
// Note [Mobile check segfaults]
|
|
// cudnn and miopen are guaranteed not to be on mobile, and T102591915 / T110194934 suggest
|
|
// that maybe the compiledWithCuDNN() check sometimes segfaults (though I can't imagine how)
|
|
#if !defined(C10_MOBILE)
|
|
if (!detail::getCUDAHooks().compiledWithCuDNN() || !input.is_cuda() || !cudnn_enabled) {
|
|
return false;
|
|
}
|
|
static long cudnn_version = detail::getCUDAHooks().versionRuntimeCuDNN();
|
|
// broken on cuDNN 9.8 - 9.14
|
|
if (cudnn_version >= 90800 && cudnn_version < 91500) {
|
|
if (cudnn_conv_suggest_memory_format(input, weight) == at::MemoryFormat::Contiguous &&
|
|
(input.scalar_type() == at::kBFloat16 || input.scalar_type() == at::kHalf) &&
|
|
weight.dim() == 5) {
|
|
for (int i = 2; i < weight.dim(); i++) {
|
|
if (weight.size(i) != 1) {
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
if (needs_64bit_indexing_no_split(input, weight)) {
|
|
if (!(cudnn_version >= 90300 && at::native::cudnnv8_enabled_check_debug())) {
|
|
TORCH_WARN_ONCE("cuDNN cannot be used for large non-batch-splittable convolutions"
|
|
" if the V8 API is not enabled or before cuDNN version 9.3+."
|
|
" Consider upgrading cuDNN and/or enabling the V8 API for better efficiency.");
|
|
return false;
|
|
}
|
|
}
|
|
if (input.scalar_type() == at::kBFloat16 || weight.scalar_type() == at::kBFloat16) {
|
|
if (!(detail::getCUDAHooks().supportsBFloat16ConvolutionWithCuDNNv8() && at::native::cudnnv8_enabled_check_debug())) {
|
|
return false;
|
|
}
|
|
}
|
|
if (cudnn_conv_suggest_memory_format(input, weight) == at::MemoryFormat::Contiguous) {
|
|
if (is_dilated()) {
|
|
return detail::getCUDAHooks().supportsDilatedConvolutionWithCuDNN() && !is_output_padding_big();
|
|
}
|
|
}
|
|
return !is_output_padding_big();
|
|
#else
|
|
return false;
|
|
#endif
|
|
}
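// Note: use_cudnn() above reads detail::getCUDAHooks().versionRuntimeCuDNN()
// rather than the compile-time cuDNN version, so users who upgrade their local
// cuDNN installation get the updated dispatching behaviour without rebuilding
// PyTorch; see https://github.com/pytorch/pytorch/pull/167111.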
|
|
|
|
// Use cudnn for FP16 depthwise convolutions
|
|
bool use_cudnn_depthwise(const at::Tensor& input, const at::Tensor& weight) const {
|
|
if (!cudnn_enabled || !detail::getCUDAHooks().compiledWithCuDNN() || !input.is_cuda()) {
|
|
return false;
|
|
}
|
|
// native kernel doesn't support 64-bit non-splittable case
|
|
if (!(canUse32BitIndexMath(input) && canUse32BitIndexMath(weight))) {
|
|
static long cudnn_version = detail::getCUDAHooks().compiledWithCuDNN() ? detail::getCUDAHooks().versionRuntimeCuDNN() : -1;
|
|
// TODO(eqy): remove this once cuDNN fixes 64-bit depthwise support, first broken in 9.11x
|
|
if (cudnn_conv_suggest_memory_format(input, weight) != at::MemoryFormat::Contiguous) {
|
|
if (cudnn_version < 0 || cudnn_version > 91000) {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
if (!(cudnn_version >= 90300 && at::native::cudnnv8_enabled_check_debug())) {
|
|
TORCH_WARN_ONCE("cuDNN cannot be used for large non-batch-splittable convolutions"
|
|
" if the V8 API is not enabled or before cuDNN version 9.3+."
|
|
" Upgrade cuDNN or enable the V8 API to use cuDNN for 64-bit depthwise convolutions.");
|
|
return false;
|
|
} else {
|
|
return true;
|
|
}
|
|
}
|
|
if (cudnn_conv_suggest_memory_format(input, weight) != at::MemoryFormat::Contiguous) {
|
|
// always use cudnn_depthwise for channels_last format
|
|
return true;
|
|
}
|
|
if (detail::getCUDAHooks().supportsDepthwiseConvolutionWithCuDNN()) {
|
|
bool kernel_cond = (use_cudnn(input, weight) &&
|
|
input.scalar_type() == kHalf && // only for FP16
|
|
weight.scalar_type() == kHalf &&
|
|
is_depthwise(input, weight) &&
|
|
input.ndimension() == 4 && // TODO: 5-D contiguous depthwise is not supported yet, need benchmarks
|
|
!is_dilated() && // no dilation supported
|
|
(stride[0] == stride[1] || at::symint::size<T>(input, 2) == 1) && // square or 1d
|
|
at::symint::size<T>(input, 1) >= 32); // min 32 channels supported
|
|
if (kernel_cond) {
|
|
return check_cudnn_depthwise_workload_with_filter<T>(input, stride[1], weight);
|
|
}
|
|
return false;
|
|
} else {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
bool use_miopen(const at::Tensor& input, const at::Tensor& weight, bool bias_defined) const {
|
|
if (needs_64bit_indexing_no_split(input, weight)) {
|
|
return false;
|
|
}
|
|
return ((input.scalar_type() == at::kFloat) || (input.scalar_type() == at::kHalf) || (input.scalar_type() == at::kBFloat16))
|
|
&& cudnn_enabled
|
|
&& input.is_cuda()
|
|
&& detail::getCUDAHooks().compiledWithMIOpen()
|
|
&& input.dim() <= MIOPEN_DIM_MAX
|
|
&& !(groups > 1 && is_dilated()) // MIOpen currently does not support dilation with groups of size > 1
|
|
;
|
|
}
|
|
bool use_mkldnn(const at::Tensor& input, const at::Tensor& weight) const {
|
|
#if AT_MKLDNN_ENABLED()
|
|
if (!at::globalContext().userEnabledMkldnn()) {
|
|
return false;
|
|
}
|
|
if (transposed && is_output_padding_big()) {
|
|
return false;
|
|
}
|
|
if (input.device().is_cpu() &&
|
|
((input.scalar_type() == at::kBFloat16 && mkldnn_bf16_device_check()) ||
|
|
(input.scalar_type() == at::kHalf && mkldnn_fp16_device_check()))) {
|
|
return true;
|
|
}
|
|
return (input.is_mkldnn()) || // input is mkldnn Tensor
|
|
(input.device().is_cpu() &&
|
|
input.scalar_type() == kFloat && // only on CPU Float Tensors
|
|
// For 1x1 filters, MKLDNN is faster than THNN when multi-threaded,
|
|
// but THNN is faster when single-threaded.
|
|
(is_strided() || is_dilated() || at::symint::size<T>(input, 0) >= 16 ||
|
|
at::symint::size<T>(weight, -1) != 1 || at::symint::size<T>(weight, -2) != 1 || at::get_num_threads() > 1) &&
|
|
(groups > 1
|
|
|| (at::symint::size<T>(weight, -1) > 3 && at::symint::size<T>(weight, -2) > 3)
|
|
|| at::symint::size<T>(input, 0) > 1
|
|
|| at::symint::size<T>(input, 0)*at::symint::size<T>(input, 1)*at::symint::size<T>(input, 2)*at::symint::size<T>(input, 3) > 20480) // for some case, native is faster
|
|
);
|
|
|
|
#endif
|
|
return false;
|
|
}
|
|
bool use_nnpack(const at::Tensor& input, const at::Tensor& weight) const {
|
|
#if AT_NNPACK_ENABLED()
|
|
return at::globalContext().userEnabledNNPACK() &&
|
|
at::_nnpack_available() &&
|
|
input.device().is_cpu() &&
|
|
input.scalar_type() == kFloat && // only on CPU Float Tensors
|
|
!is_dilated() && // or dilation
|
|
!transposed && // or transposed tensors
|
|
input.ndimension() == 4 && // must be in NCHW format
|
|
weight.ndimension() == 4 &&
|
|
(at::symint::size<T>(weight, 2) < 17) && (at::symint::size<T>(weight, 3) < 17) && // NNPACK only supports kernels up to 16x16
|
|
(padding[0] < at::symint::size<T>(weight, 2)) && (padding[1] < at::symint::size<T>(weight, 3)) // NNPACK only supports padding < kernel_size. See https://github.com/pytorch/pytorch/issues/90142.
|
|
#if !defined(C10_MOBILE)
|
|
&& at::symint::size<T>(input, 0) >= 16 // ensure large enough batch size to ensure perf, tuneable
|
|
#endif
|
|
;
|
|
#endif
|
|
return false;
|
|
}
|
|
bool use_xnnpack(const at::Tensor& input, const at::Tensor& weight,
|
|
const at::OptionalArrayRef<T> bias_sizes_opt) const {
|
|
#if defined(C10_MOBILE)
|
|
if (!transposed) {
|
|
// NB: for the call here, it MATTERS that we are templated. If you
|
|
// untemplate this to always use SymInt, the function
|
|
// xnnpack_use_convolution2d will always return false
|
|
return (at::symint::size<T>(input, 1) == groups) &&
|
|
xnnpack_use_convolution2d(
|
|
input,
|
|
weight,
|
|
bias_sizes_opt,
|
|
padding,
|
|
stride,
|
|
dilation,
|
|
groups,
|
|
transposed);
|
|
}
|
|
#endif
|
|
return false;
|
|
}
|
|
|
|
bool use_mps(const at::Tensor& input, const at::Tensor& weight) const {
|
|
// These checks need to be expanded. Currently we have a very limited set of
|
|
// checks for MPS.
|
|
#ifdef USE_MPS
|
|
if (needs_64bit_indexing_no_split(input, weight)) {
|
|
return false;
|
|
}
|
|
if (!input.is_mps()) {
|
|
return false;
|
|
}
|
|
return true;
|
|
#else
|
|
return false;
|
|
#endif
|
|
}
|
|
|
|
// We currently only have depthwise support for the case where groups ==
|
|
// nInputPlane and nInputPlane == nOutputPlane (the latter due to the lack of
|
|
// a depthwise multiplier)
|
|
bool is_depthwise(const at::Tensor& input, const at::Tensor& weight) const {
|
|
return input.is_cuda() &&
|
|
!transposed &&
|
|
(input.ndimension() == 4 || input.ndimension() == 5) &&
|
|
at::symint::size<T>(input, 1) == groups &&
|
|
groups > 1 && // no point if there is only a single group
|
|
at::symint::size<T>(weight, 0) % at::symint::size<T>(input, 1) == 0; // output channels must be a multiple of input channels
|
|
}
|
|
};
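// Illustrative sketch of how a ConvParams<int64_t> is typically populated
// before backend selection (mirrors select_conv_backend below; the literal
// stride/padding/groups values are assumptions for the example only):
//
//   ConvParams<int64_t> params;
//   params.stride         = {1, 1};
//   params.padding        = {1, 1};
//   params.dilation       = {1, 1};
//   params.transposed     = false;
//   params.output_padding = {0, 0};
//   params.groups         = 1;
//   params.benchmark      = at::globalContext().benchmarkCuDNN();
//   params.deterministic  = at::globalContext().deterministicCuDNN() ||
//                           at::globalContext().deterministicAlgorithms();
//   params.cudnn_enabled  = at::globalContext().userEnabledCuDNN();
//   params.allow_tf32     = at::globalContext().allowTF32CuDNN(at::Float32Op::CONV);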
|
|
|
|
DEFINE_DISPATCH(conv_depthwise2d_backward_stub);
|
|
DEFINE_DISPATCH(conv_depthwise3d_backward_stub);
|
|
DEFINE_DISPATCH(cudnn_convolution_backward_stub);
|
|
DEFINE_DISPATCH(cudnn_convolution_transpose_backward_stub);
|
|
DEFINE_DISPATCH(slow_conv_transpose3d_backward_stub);
|
|
DEFINE_DISPATCH(convolution_depthwise3x3_winograd_stub);
|
|
DEFINE_DISPATCH(miopen_convolution_backward_stub);
|
|
DEFINE_DISPATCH(miopen_convolution_transpose_backward_stub);
|
|
DEFINE_DISPATCH(miopen_depthwise_convolution_backward_stub);
|
|
DEFINE_DISPATCH(mkldnn_convolution_backward_stub);
|
|
DEFINE_DISPATCH(mkldnn_convolution_transpose_stub);
|
|
DEFINE_DISPATCH(mkldnn_convolution_transpose_backward_stub);
|
|
DEFINE_DISPATCH(slow_conv_dilated2d_backward_stub);
|
|
DEFINE_DISPATCH(slow_conv_dilated3d_backward_stub);
|
|
DEFINE_DISPATCH(slow_conv_transpose2d_backward_stub);
|
|
REGISTER_NO_CPU_DISPATCH(conv_depthwise2d_backward_stub)
|
|
REGISTER_NO_CPU_DISPATCH(conv_depthwise3d_backward_stub)
|
|
REGISTER_NO_CPU_DISPATCH(cudnn_convolution_backward_stub)
|
|
REGISTER_NO_CPU_DISPATCH(cudnn_convolution_transpose_backward_stub)
|
|
REGISTER_NO_CPU_DISPATCH(miopen_convolution_backward_stub)
|
|
REGISTER_NO_CPU_DISPATCH(miopen_convolution_transpose_backward_stub)
|
|
REGISTER_NO_CPU_DISPATCH(miopen_depthwise_convolution_backward_stub)
|
|
|
|
template <typename T>
|
|
static std::ostream& operator<<(std::ostream & out, const ConvParams<T>& params) {
|
|
out << "ConvParams {"
|
|
<< " stride = " << IntArrayRef{params.stride}
|
|
<< " padding = " << ArrayRef<T>{params.padding}
|
|
<< " dilation = " << IntArrayRef{params.dilation}
|
|
<< " transposed = " << params.transposed
|
|
<< " output_padding = " << ArrayRef<T>{params.output_padding}
|
|
<< " groups = " << params.groups
|
|
<< " benchmark = " << params.benchmark
|
|
<< " deterministic = " << params.deterministic
|
|
<< " cudnn_enabled = " << params.cudnn_enabled
|
|
<< " allow_tf32 = " << params.allow_tf32
|
|
<< "}";
|
|
return out;
|
|
}
|
|
|
|
template <typename T>
|
|
static void check_shape_forward(const at::Tensor& input,
|
|
const c10::ArrayRef<T>& weight_sizes, const at::Tensor& bias,
|
|
const ConvParams<T>& params) {
|
|
int64_t k = input.ndimension();
|
|
int64_t weight_dim = weight_sizes.size();
|
|
auto groups = params.groups;
|
|
const auto& padding = params.padding;
|
|
const auto& dilation = params.dilation;
|
|
bool transposed = params.transposed;
|
|
|
|
TORCH_CHECK(!params.is_padding_neg(), "negative padding is not supported");
|
|
TORCH_CHECK(!params.is_output_padding_neg(), "negative output_padding is not supported");
|
|
TORCH_CHECK(!params.is_stride_nonpos(), "non-positive stride is not supported");
|
|
TORCH_CHECK(!params.is_dilation_neg(), "dilation should be greater than zero");
|
|
TORCH_CHECK(groups > 0, "expected groups to be greater than 0, but got groups=", groups);
|
|
|
|
TORCH_CHECK(weight_dim == k,
|
|
"Expected ", weight_dim, "-dimensional input for ", weight_dim,
|
|
"-dimensional weight ", weight_sizes, ", but got ", k, "-dimensional input of size ",
|
|
at::symint::sizes<T>(input), " instead");
|
|
TORCH_CHECK(weight_sizes[0] >= groups,
|
|
"Given groups=", groups, ", expected weight to be at least ", groups,
|
|
" at dimension 0, but got weight of size ", weight_sizes, " instead");
|
|
TORCH_CHECK(weight_sizes[0] % groups == 0,
|
|
"Given groups=", groups, ", expected weight to be divisible by ",
|
|
groups, " at dimension 0, but got weight of size [", weight_sizes,
|
|
"] instead");
|
|
|
|
if (!transposed) {
|
|
std::vector<T> input_shape;
|
|
std::vector<T> kernel_shape;
|
|
bool kernel_size_correct = true;
|
|
|
|
TORCH_CHECK(at::symint::size<T>(input, 1) == (weight_sizes[1] * groups),
|
|
"Given groups=", groups, ", weight of size ", weight_sizes,
|
|
", expected input", at::symint::sizes<T>(input), " to have ",
|
|
(weight_sizes[1] * groups), " channels, but got ", at::symint::size<T>(input, 1),
|
|
" channels instead");
|
|
|
|
TORCH_CHECK(!bias.defined() || (bias.ndimension() == 1 && at::symint::size<T>(bias, 0) == weight_sizes[0]),
|
|
"Given weight of size ", weight_sizes,
|
|
", expected bias to be 1-dimensional with ", weight_sizes[0], " elements",
|
|
", but got bias of size ", at::symint::sizes<T>(bias), " instead");
|
|
|
|
for (const auto i : c10::irange(2, k)) {
|
|
// T could be int64_t or SymInt; numeric_limits<SymInt> is specialized in c10/core/SymInt.h
|
|
TORCH_CHECK(padding[i-2] <= (std::numeric_limits<T>::max() - padding[i-2]),
|
|
"Given padding=", padding[i-2], " at dimension ", i-2, " , expected padding to be at most ",
|
|
(std::numeric_limits<T>::max() / 2));
|
|
input_shape.push_back(at::symint::size<T>(input, i) + 2 * padding[i-2]);
|
|
// log new kernel size considering dilation
|
|
kernel_shape.push_back(dilation[i-2] * (weight_sizes[i]-1) + 1);
|
|
if (input_shape.back() < kernel_shape.back()) {
|
|
kernel_size_correct = false;
|
|
}
|
|
}
|
|
|
|
TORCH_CHECK(input_shape.size() == kernel_shape.size(), "Inconsistent shape between Input and Kernel");
|
|
|
|
if (!kernel_size_correct) {
|
|
// If kernel size is incorrect
|
|
std::ostringstream input_ss;
|
|
std::ostringstream kernel_ss;
|
|
std::string separator;
|
|
|
|
for (int i = 0, len = input_shape.size(); i < len; ++i) {
|
|
input_ss << separator << input_shape[i];
|
|
kernel_ss << separator << kernel_shape[i];
|
|
separator = " x ";
|
|
}
|
|
|
|
TORCH_CHECK(false, "Calculated padded input size per channel: (", input_ss.str(), "). "
|
|
"Kernel size: (", kernel_ss.str(), "). Kernel size can't be greater than actual input size");
|
|
}
|
|
} else { // transposed
|
|
for (const auto i : c10::irange(2, k)) {
|
|
TORCH_CHECK(padding[i-2] <= (std::numeric_limits<T>::max() - padding[i-2]),
|
|
"Given padding=", padding[i-2], " at dimension ", i-2, " , expected padding to be at most ",
|
|
(std::numeric_limits<T>::max() / 2));
|
|
}
|
|
TORCH_CHECK(at::symint::size<T>(input, 1) == weight_sizes[0],
|
|
"Given transposed=", transposed, ", weight of size ", weight_sizes,
|
|
", expected input", at::symint::sizes<T>(input), " to have ", weight_sizes[0],
|
|
" channels, but got ", at::symint::size<T>(input, 1), " channels instead");
|
|
TORCH_CHECK(!bias.defined() || (bias.ndimension() == 1 && at::symint::size<T>(bias, 0) == weight_sizes[1] * groups),
|
|
"Given transposed=", transposed, ", weight of size ", weight_sizes,
|
|
", expected bias to be 1-dimensional with ", weight_sizes[1] * groups, " elements",
|
|
", but got bias of size ", at::symint::sizes<T>(bias), " instead");
|
|
}
|
|
}
|
|
|
|
template <typename T>
|
|
static void check_shape_backward(
|
|
const at::Tensor& input,
|
|
const c10::ArrayRef<T>& weight_sizes,
|
|
const ConvParams<T>& params) {
|
|
check_shape_forward<T>(input, weight_sizes, /*bias=*/ Tensor(), params);
|
|
}
|
|
|
|
// Given an input tensor and an expected number of spatial dimensions, checks that the
|
|
// input is a valid shape and returns the batched form of the input.
|
|
//
|
|
// Args:
|
|
// input (Tensor): Input tensor
|
|
// num_spatial_dims (int): Number of spatial dimensions expected for the input
|
|
// func_name (string): Function name to produce a nice error message for invalid input
|
|
//
|
|
// Returns a std::tuple containing:
|
|
// batched_input (Tensor): Input with a batch dimension
|
|
// is_batched (bool): Indicates whether the original input was already batched
|
|
static std::tuple<Tensor, bool> batchify(
|
|
const Tensor& input,
|
|
const int64_t num_spatial_dims,
|
|
const std::string& func_name) {
|
|
// assume nested tensors (NTs) are always batched
|
|
if (input.is_nested()) {
|
|
return std::make_tuple(input, true);
|
|
}
|
|
const auto dim_count_no_batch = num_spatial_dims + 1;
|
|
const auto dim_count_batch = dim_count_no_batch + 1;
|
|
const auto is_batched = (input.dim() == dim_count_batch);
|
|
TORCH_CHECK(input.dim() == dim_count_no_batch || is_batched,
|
|
"Expected ", dim_count_no_batch, "D (unbatched) or ", dim_count_batch,
|
|
"D (batched) input to ", func_name, ", but got input of size: ", input.sizes());
|
|
return std::make_tuple(is_batched ? input : input.unsqueeze(0), is_batched);
|
|
}
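// Illustrative usage (shapes assumed): given a conv2d call with an unbatched
// input of shape [3, 32, 32], batchify(input, /*num_spatial_dims=*/2, "conv2d")
// returns ({1, 3, 32, 32}, /*is_batched=*/false); the caller runs the batched
// convolution and squeezes dim 0 again before returning, as conv2d_symint
// does below.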
|
|
|
|
static void check_input_same_type_as_parameters(
|
|
const Tensor& input,
|
|
const Tensor& weight,
|
|
const Tensor& bias) {
|
|
TORCH_CHECK(input.options().type_equal(weight.options()),
|
|
"Input type (", input.toString(), ") and weight type (", weight.toString(),
|
|
") should be the same");
|
|
TORCH_CHECK(!bias.defined() || (input.options().type_equal(bias.options())),
|
|
"Input type (", input.toString(), ") and bias type (", bias.toString(),
|
|
") should be the same");
|
|
}
|
|
|
|
static void check_input_same_type_as_parameters(
|
|
const Tensor& input,
|
|
const Tensor& weight) {
|
|
check_input_same_type_as_parameters(input, weight, /*bias=*/ Tensor());
|
|
}
|
|
|
|
#if AT_MKLDNN_ENABLED()
|
|
static void check_input_same_type_as_parameters(
|
|
const Tensor& input,
|
|
const Tensor& weight,
|
|
const Tensor& bias,
|
|
const ConvBackend backend) {
|
|
if (backend == ConvBackend::Mkldnn || backend == ConvBackend::MkldnnTranspose) {
|
|
TORCH_CHECK(input.options().type_equal(weight.options())
|
|
|| (input.is_mkldnn() && weight.device().is_cpu() && weight.scalar_type() == kFloat),
|
|
"Input type (", input.toString(), ") and weight type (", weight.toString(),
|
|
") should be the same or input should be a MKLDNN tensor and weight is a dense tensor");
|
|
TORCH_CHECK(!bias.defined() || (input.options().type_equal(bias.options()))
|
|
|| (input.is_mkldnn() && bias.device().is_cpu() && bias.scalar_type() == kFloat),
|
|
"Input type (", input.toString(), ") and bias type (", bias.toString(),
|
|
") should be the same or input should be a MKLDNN tensor and bias is a dense tensor");
|
|
} else {
|
|
check_input_same_type_as_parameters(input, weight, bias);
|
|
}
|
|
}
|
|
#endif
|
|
|
|
static auto view4d(const at::Tensor& tensor) -> at::Tensor {
|
|
TORCH_CHECK(tensor.ndimension() == 3,
|
|
"expected 3D tensor, got tensor with ", tensor.ndimension(),
|
|
" dimensions instead");
|
|
return tensor.unsqueeze(2);
|
|
}
|
|
|
|
static auto view3d(const at::Tensor& tensor) -> at::Tensor {
|
|
TORCH_CHECK(tensor.ndimension() == 4,
|
|
"expected 4D tensor, got tensor with ", tensor.ndimension(),
|
|
" dimensions instead");
|
|
return tensor.squeeze(2);
|
|
}
|
|
|
|
static at::Tensor subtensor(at::Tensor& tensor, int64_t dim, int64_t groups, int64_t g) {
|
|
if (!tensor.defined()) {
|
|
return at::Tensor();
|
|
}
|
|
const auto memory_format = tensor.suggest_memory_format();
|
|
int64_t n = tensor.sizes()[dim] / groups;
|
|
return tensor.narrow(dim, n * g, n).contiguous(memory_format);
|
|
}
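// Illustrative example (sizes assumed): for a weight of shape [64, 8, 3, 3]
// and groups == 4, subtensor(weight, /*dim=*/0, /*groups=*/4, g) returns the
// contiguous [16, 8, 3, 3] slice for group g, i.e. weight.narrow(0, 16 * g, 16).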
|
|
|
|
namespace {
|
|
|
|
std::pair<Tensor, Tensor> complex_to_real(const Tensor& inp) {
|
|
auto inp_view_as_complex = at::view_as_real(inp);
|
|
auto dim_i = inp_view_as_complex.dim() - 1;
|
|
auto i_r = inp_view_as_complex.select(dim_i, 0);
|
|
auto i_i = inp_view_as_complex.select(dim_i, 1);
|
|
return std::make_pair(i_r, i_i);
|
|
}
|
|
|
|
at::Tensor complex_convolution(
|
|
const Tensor& input,
|
|
const Tensor& weight,
|
|
const Tensor& bias,
|
|
SymIntArrayRef stride,
|
|
SymIntArrayRef padding,
|
|
SymIntArrayRef dilation,
|
|
bool transposed,
|
|
SymIntArrayRef output_padding,
|
|
const c10::SymInt& groups) {
|
|
check_input_same_type_as_parameters(input, weight, bias);
|
|
auto [i_r, i_i] = complex_to_real(input.resolve_conj());
|
|
auto [w_r, w_i] = complex_to_real(weight.resolve_conj());
|
|
|
|
// [NOTE] Complex Convolution
|
|
// conv(W, x, b) = conv(Wr, xr, br) - conv(Wi, xi, 0) + i(conv(Wi, xr, bi) + conv(Wr, xi, 0))
|
|
// where W, x and b are all complex inputs.
|
|
// With Gauss Trick:
|
|
// a = conv(Wr, xr, br),
|
|
// b = conv(Wi, xi, 0),
|
|
// c = conv(Wr + Wi, xr + xi, bi + br)
|
|
// conv(W, x, b) = a - b + i(c - a - b)
|
|
Tensor a, b, c;
|
|
if (!bias.defined()) {
|
|
a = at::convolution_symint(i_r, w_r, bias, stride, padding, dilation, transposed, output_padding, groups);
|
|
b = at::convolution_symint(i_i, w_i, bias, stride, padding, dilation, transposed, output_padding, groups);
|
|
c = at::convolution_symint(i_r + i_i, w_r + w_i, bias, stride, padding, dilation, transposed, output_padding, groups);
|
|
} else {
|
|
auto [b_r, b_i] = complex_to_real(bias.resolve_conj());
|
|
a = at::convolution_symint(i_r, w_r, b_r, stride, padding, dilation, transposed, output_padding, groups);
|
|
b = at::convolution_symint(i_i, w_i, Tensor(), stride, padding, dilation, transposed, output_padding, groups);
|
|
c = at::convolution_symint(i_r + i_i, w_r + w_i, b_r + b_i, stride, padding, dilation, transposed, output_padding, groups);
|
|
}
|
|
|
|
auto i = c10::Scalar(c10::complex<double>(0, 1));
|
|
return a - b + i * (c - a - b);
|
|
}
|
|
|
|
at::Tensor complex_convolution_mode(
|
|
const at::Tensor& input,
|
|
const at::Tensor& weight,
|
|
const std::optional<at::Tensor>& bias_opt,
|
|
c10::SymIntArrayRef stride,
|
|
std::string_view padding,
|
|
c10::SymIntArrayRef dilation,
|
|
const c10::SymInt& groups) {
|
|
auto bias = bias_opt.value_or(Tensor());
|
|
check_input_same_type_as_parameters(input, weight, bias);
|
|
auto [i_r, i_i] = complex_to_real(input.resolve_conj());
|
|
auto [w_r, w_i] = complex_to_real(weight.resolve_conj());
|
|
|
|
// See [NOTE] Complex Convolution
|
|
Tensor a, b, c;
|
|
if (!bias.defined()) {
|
|
a = at::_convolution_mode_symint(i_r, w_r, bias, stride, padding, dilation, groups);
|
|
b = at::_convolution_mode_symint(i_i, w_i, bias, stride, padding, dilation, groups);
|
|
c = at::_convolution_mode_symint(i_r + i_i, w_r + w_i, bias, stride, padding, dilation, groups);
|
|
} else {
|
|
auto [b_r, b_i] = complex_to_real(bias.resolve_conj());
|
|
a = at::_convolution_mode_symint(i_r, w_r, b_r, stride, padding, dilation, groups);
|
|
b = at::_convolution_mode_symint(i_i, w_i, Tensor(), stride, padding, dilation, groups);
|
|
c = at::_convolution_mode_symint(i_r + i_i, w_r + w_i, b_r + b_i, stride, padding, dilation, groups);
|
|
}
|
|
|
|
auto i = c10::Scalar(c10::complex<double>(0, 1));
|
|
return a - b + i * (c - a - b);
|
|
}
|
|
|
|
} // namespace
|
|
|
|
at::Tensor conv1d_symint(
|
|
const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias_opt,
|
|
SymIntArrayRef stride, SymIntArrayRef padding, SymIntArrayRef dilation, c10::SymInt groups) {
|
|
// See [Note: hacky wrapper removal for optional tensor]
|
|
c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
|
|
const Tensor& bias = *bias_maybe_owned;
|
|
|
|
TORCH_CHECK(
|
|
!bias.defined() || bias.dtype() == input_.dtype(),
|
|
"Input type (",
|
|
input_.dtype().name(),
|
|
") and bias type (",
|
|
bias.dtype().name(),
|
|
") should be the same");
|
|
|
|
auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 1, "conv1d");
|
|
Tensor output;
|
|
if (at::isComplexType(input_.scalar_type())) {
|
|
output = complex_convolution(input, weight, bias, stride, padding, dilation, false, {0}, groups);
|
|
} else {
|
|
output = at::convolution_symint(input, weight, bias, stride, padding, dilation, false, {0}, groups);
|
|
}
|
|
return is_batched ? std::move(output) : output.squeeze(0);
|
|
}
|
|
|
|
at::Tensor conv2d_symint(
|
|
const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias_opt,
|
|
SymIntArrayRef stride, SymIntArrayRef padding, SymIntArrayRef dilation, c10::SymInt groups) {
|
|
// See [Note: hacky wrapper removal for optional tensor]
|
|
c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
|
|
const Tensor& bias = *bias_maybe_owned;
|
|
|
|
TORCH_CHECK(
|
|
!bias.defined() || bias.dtype() == input_.dtype(),
|
|
"Input type (",
|
|
input_.dtype().name(),
|
|
") and bias type (",
|
|
bias.dtype().name(),
|
|
") should be the same");
|
|
|
|
auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 2, "conv2d");
|
|
Tensor output;
|
|
if (at::isComplexType(input_.scalar_type())) {
|
|
output = complex_convolution(input, weight, bias, stride, padding, dilation, false, {{0, 0}}, groups);
|
|
} else {
|
|
output = at::convolution_symint(input, weight, bias, stride, padding, dilation, false, {{0, 0}}, groups);
|
|
}
|
|
return is_batched ? std::move(output) : output.squeeze(0);
|
|
}
|
|
|
|
at::Tensor conv3d_symint(
|
|
const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias_opt,
|
|
SymIntArrayRef stride, SymIntArrayRef padding, SymIntArrayRef dilation, c10::SymInt groups) {
|
|
// See [Note: hacky wrapper removal for optional tensor]
|
|
c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
|
|
const Tensor& bias = *bias_maybe_owned;
|
|
|
|
TORCH_CHECK(
|
|
!bias.defined() || bias.dtype() == input_.dtype(),
|
|
"Input type (",
|
|
input_.dtype().name(),
|
|
") and bias type (",
|
|
bias.dtype().name(),
|
|
") should be the same");
|
|
|
|
auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 3, "conv3d");
|
|
Tensor output;
|
|
if (at::isComplexType(input_.scalar_type())) {
|
|
output = complex_convolution(input, weight, bias, stride, padding, dilation, false, {{0, 0, 0}}, groups);
|
|
} else {
|
|
output = at::convolution_symint(input, weight, bias, stride, padding, dilation, false, {{0, 0, 0}}, groups);
|
|
}
|
|
return is_batched ? std::move(output) : output.squeeze(0);
|
|
}
|
|
|
|
|
|
static Tensor convolution_same(
|
|
const Tensor &input, const Tensor &weight, const Tensor &bias,
|
|
SymIntArrayRef stride, SymIntArrayRef dilation, const c10::SymInt& groups) {
|
|
|
|
auto k = weight.dim();
|
|
TORCH_CHECK(k > 2, "weight should have at least three dimensions");
|
|
TORCH_CHECK(groups > 0, "non-positive groups is not supported");
|
|
auto dim = static_cast<size_t>(k - 2);
|
|
auto weight_sizes = weight.sym_sizes();
|
|
auto input_sizes = input.sym_sizes();
|
|
TORCH_CHECK(k == input.dim(),
|
|
"Expected ", k, "-dimensional input for ",
|
|
k, "-dimensional weight", weight_sizes, ", but got ",
|
|
input.dim(), "-dimensional input of size ",
|
|
input.sizes(), " instead");
|
|
TORCH_CHECK(stride.size() == dim || stride.size() == 1U,
|
|
"stride cannot broadcast to ", dim, " dimensions");
|
|
TORCH_CHECK(dilation.size() == dim || dilation.size() == 1U,
|
|
"dilation cannot broadcast to ", dim, " dimensions");
|
|
for (auto i: c10::irange(stride.size())) {
|
|
TORCH_CHECK(stride[i] == 1, "padding='same' is not supported for strided convolutions");
|
|
}
|
|
|
|
// Calculate the correct padding
|
|
SymDimVector padding_l, padding_r;
|
|
bool symmetric_padding = true;
|
|
for (auto i: c10::irange(dim)) {
|
|
auto s = stride.size() == 1 ? stride[0] : stride[i];
|
|
auto d = dilation.size() == 1 ? dilation[0] : dilation[i];
|
|
auto pad = pooling_same_mode_padding_lr(
|
|
input_sizes[i + 2], weight_sizes[i + 2], s, d);
|
|
padding_l.push_back(pad.first);
|
|
padding_r.push_back(pad.second);
|
|
if (pad.first != pad.second) {
|
|
symmetric_padding = false;
|
|
}
|
|
}
|
|
|
|
if (symmetric_padding) {
|
|
// All backends handle symmetric padding natively
|
|
SymDimVector output_padding(dim);
|
|
return at::convolution_symint(input, weight, bias, stride, padding_l, dilation,
|
|
false, output_padding, groups);
|
|
}
|
|
|
|
TORCH_WARN_ONCE("Using padding='same' with even kernel lengths and odd dilation may"
|
|
" require a zero-padded copy of the input be created");
|
|
SmallVector<c10::SymInt, kDimVectorStaticSize * 2> pad_nd(static_cast<size_t>(2 * dim));
|
|
for (auto i: c10::irange(dim)) {
|
|
// Apply padding by the difference, leaving only a symmetric padding
|
|
auto delta_pad = padding_r[i] - padding_l[i];
|
|
auto pad_idx = 2 * (dim - 1 - i); // F.pad goes from last dim to first
|
|
if (delta_pad > 0) {
|
|
pad_nd[pad_idx + 1] = delta_pad;
|
|
} else {
|
|
pad_nd[pad_idx] = delta_pad;
|
|
padding_l[i] = padding_r[i];
|
|
}
|
|
}
|
|
auto padded_input = at::constant_pad_nd_symint(input, pad_nd, 0);
|
|
SymDimVector output_padding(dim);
|
|
return at::convolution_symint(padded_input, weight, bias, stride, padding_l,
|
|
dilation, false, output_padding, groups);
|
|
}
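// Illustrative example of the padding computed above (shapes assumed): for a
// length-10 input, kernel size 4, stride 1 and dilation 1, the effective
// kernel covers 4 elements, so 3 total padding elements are needed to keep the
// output length at 10. pooling_same_mode_padding_lr splits this asymmetrically
// (1 on one side, 2 on the other), so the extra element of padding is applied
// to the input up front via at::constant_pad_nd_symint and only the remaining
// symmetric padding of 1 per side is passed to at::convolution_symint.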
|
|
|
|
Tensor _convolution_mode_symint(
|
|
const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias_opt,
|
|
SymIntArrayRef stride, std::string_view padding, SymIntArrayRef dilation,
|
|
c10::SymInt groups) {
|
|
// See [Note: hacky wrapper removal for optional tensor]
|
|
c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
|
|
const Tensor& bias = *bias_maybe_owned;
|
|
|
|
if (padding == "same") {
|
|
return at::native::convolution_same(
|
|
input, weight, bias, stride, dilation, groups);
|
|
} else if (padding == "valid") {
|
|
return at::convolution_symint(
|
|
input, weight, bias, stride, {{0}}, dilation, false, {{0}}, groups);
|
|
}
|
|
TORCH_CHECK(false, "Invalid padding string: '", padding, "'");
|
|
}
|
|
|
|
at::Tensor conv1d_padding_symint(
|
|
const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias,
|
|
c10::SymIntArrayRef stride, std::string_view padding, c10::SymIntArrayRef dilation,
|
|
c10::SymInt groups) {
|
|
auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 1, "conv1d");
|
|
Tensor output;
|
|
if (at::isComplexType(input_.scalar_type())) {
|
|
output = complex_convolution_mode(input, weight, bias, stride, padding, dilation, groups);
|
|
} else {
|
|
output = at::_convolution_mode_symint(input, weight, bias, stride, padding, dilation, groups);
|
|
}
|
|
return is_batched ? std::move(output) : output.squeeze(0);
|
|
}
|
|
|
|
at::Tensor conv2d_padding_symint(
|
|
const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias,
|
|
c10::SymIntArrayRef stride, std::string_view padding, c10::SymIntArrayRef dilation,
|
|
c10::SymInt groups) {
|
|
auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 2, "conv2d");
|
|
Tensor output;
|
|
if (at::isComplexType(input_.scalar_type())) {
|
|
output = complex_convolution_mode(input, weight, bias, stride, padding, dilation, groups);
|
|
} else {
|
|
output = at::_convolution_mode_symint(input, weight, bias, stride, padding, dilation, groups);
|
|
}
|
|
return is_batched ? std::move(output) : output.squeeze(0);
|
|
}
|
|
|
|
at::Tensor conv3d_padding_symint(
|
|
const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias,
|
|
c10::SymIntArrayRef stride, std::string_view padding, c10::SymIntArrayRef dilation,
|
|
c10::SymInt groups) {
|
|
auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 3, "conv3d");
|
|
Tensor output;
|
|
if (at::isComplexType(input_.scalar_type())) {
|
|
output = complex_convolution_mode(input, weight, bias, stride, padding, dilation, groups);
|
|
} else {
|
|
output = at::_convolution_mode_symint(input, weight, bias, stride, padding, dilation, groups);
|
|
}
|
|
return is_batched ? std::move(output) : output.squeeze(0);
|
|
}
|
|
|
|
at::Tensor conv_transpose1d_symint(
|
|
const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias_opt,
|
|
SymIntArrayRef stride, SymIntArrayRef padding, SymIntArrayRef output_padding, c10::SymInt groups, SymIntArrayRef dilation) {
|
|
// See [Note: hacky wrapper removal for optional tensor]
|
|
c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
|
|
const Tensor& bias = *bias_maybe_owned;
|
|
|
|
auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 1, "conv_transpose1d");
|
|
Tensor output;
|
|
if (at::isComplexType(input_.scalar_type())) {
|
|
output = complex_convolution(
|
|
input, weight, bias, stride, padding, dilation, true, output_padding, groups);
|
|
} else {
|
|
output = at::convolution_symint(
|
|
input, weight, bias, stride, padding, dilation, true, output_padding, groups);
|
|
}
|
|
return is_batched ? std::move(output) : output.squeeze(0);
|
|
}
|
|
|
|
at::Tensor conv_transpose2d_symint(
|
|
const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias_opt,
|
|
SymIntArrayRef stride, SymIntArrayRef padding, SymIntArrayRef output_padding, c10::SymInt groups, SymIntArrayRef dilation) {
|
|
// See [Note: hacky wrapper removal for optional tensor]
|
|
c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
|
|
const Tensor& bias = *bias_maybe_owned;
|
|
|
|
auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 2, "conv_transpose2d");
|
|
Tensor output;
|
|
if (at::isComplexType(input_.scalar_type())) {
|
|
output = complex_convolution(
|
|
input, weight, bias, stride, padding, dilation, true, output_padding, groups);
|
|
} else {
|
|
output = at::convolution_symint(
|
|
input, weight, bias, stride, padding, dilation, true, output_padding, groups);
|
|
}
|
|
return is_batched ? std::move(output) : output.squeeze(0);
|
|
}
|
|
|
|
at::Tensor conv_transpose3d_symint(
|
|
const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias_opt,
|
|
SymIntArrayRef stride, SymIntArrayRef padding, SymIntArrayRef output_padding, c10::SymInt groups, SymIntArrayRef dilation) {
|
|
// See [Note: hacky wrapper removal for optional tensor]
|
|
c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
|
|
const Tensor& bias = *bias_maybe_owned;
|
|
|
|
auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 3, "conv_transpose3d");
|
|
Tensor output;
|
|
if (at::isComplexType(input_.scalar_type())) {
|
|
output = complex_convolution(
|
|
input, weight, bias, stride, padding, dilation, true, output_padding, groups);
|
|
} else {
|
|
output = at::convolution_symint(
|
|
input, weight, bias, stride, padding, dilation, true, output_padding, groups);
|
|
}
|
|
return is_batched ? std::move(output) : output.squeeze(0);
|
|
}
|
|
|
|
at::Tensor convolution(
|
|
const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias_opt,
|
|
IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation,
|
|
bool transposed, IntArrayRef output_padding, int64_t groups) {
|
|
// See [Note: hacky wrapper removal for optional tensor]
|
|
c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
|
|
const Tensor& bias = *bias_maybe_owned;
|
|
|
|
auto& ctx = at::globalContext();
|
|
// See Note [Enabling Deterministic Operations]
|
|
bool deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
|
|
return at::_convolution(input, weight, bias, stride, padding, dilation,
|
|
transposed, output_padding, groups,
|
|
ctx.benchmarkCuDNN(), deterministic, ctx.userEnabledCuDNN(), ctx.allowTF32CuDNN(at::Float32Op::CONV));
|
|
}
|
|
|
|
at::Tensor convolution_overrideable(
|
|
const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias_opt,
|
|
IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation,
|
|
bool transposed, IntArrayRef output_padding, int64_t groups) {
|
|
TORCH_CHECK_NOT_IMPLEMENTED(false, "convolution_overrideable not implemented. You are likely triggering this with a tensor backend other than CPU/CUDA/MKLDNN; if this is intended, please use TORCH_LIBRARY_IMPL to override this function.");
|
|
}
|
|
|
|
// Function to select the convolution backend based on the inputs and params.
|
|
// This overload is used within the convolution internals but not exposed to python.
|
|
// NB: The forward pass provides a bias tensor while the backward pass provides
|
|
// a bool indicating whether the bias is defined. This is done to save memory by
|
|
// avoiding saving the full bias tensor for backward.
|
|
template <typename T>
|
|
static ConvBackend _select_conv_backend(
|
|
const Tensor& input,
|
|
const Tensor& weight,
|
|
const std::optional<Tensor>& bias,
|
|
const at::OptionalArrayRef<T> bias_sizes_opt,
|
|
const bool need_backward,
|
|
const ConvParams<T>& params) {
|
|
|
|
// don't send empty inputs through backends
|
|
if (at::symint::size<T>(input, 0) == 0 || at::symint::size<T>(input, 1) == 0) {
|
|
return input.is_mkldnn() ? ConvBackend::MkldnnEmpty : ConvBackend::Empty;
|
|
} else if (at::symint::numel<T>(input) == 0) {
|
|
TORCH_CHECK(false, "Only zero batch or zero channel inputs are supported, but got input shape: ", at::symint::sizes<T>(input));
|
|
}
|
|
|
|
if (params.is_depthwise(input, weight)) {
|
|
if (params.use_cudnn_depthwise(input, weight)) {
|
|
return ConvBackend::Cudnn;
|
|
} else if (params.use_miopen(input, weight, bias_sizes_opt.has_value())) {
|
|
return ConvBackend::MiopenDepthwise;
|
|
} else {
|
|
if (input.ndimension() == 4) {
|
|
return ConvBackend::CudaDepthwise2d;
|
|
} else if (input.ndimension() == 5) {
|
|
return ConvBackend::CudaDepthwise3d;
|
|
} else {
|
|
// unsupported
|
|
}
|
|
}
|
|
} else if (params.use_cudnn(input, weight)) {
|
|
if (params.transposed) {
|
|
return ConvBackend::CudnnTranspose;
|
|
} else {
|
|
return ConvBackend::Cudnn;
|
|
}
|
|
} else if (params.use_miopen(input, weight, bias_sizes_opt.has_value())) {
|
|
if (params.transposed) {
|
|
return ConvBackend::MiopenTranspose;
|
|
} else {
|
|
return ConvBackend::Miopen;
|
|
}
|
|
} else if (params.use_mkldnn(input, weight)) {
|
|
if (params.transposed) {
|
|
return ConvBackend::MkldnnTranspose;
|
|
} else {
|
|
return ConvBackend::Mkldnn;
|
|
}
|
|
} else if (!need_backward && params.use_xnnpack(input, weight, bias_sizes_opt)) {
|
|
// Using prepacked conv is preferred, but XNNPACK is still the fastest
|
|
// option for NHWC.
|
|
return ConvBackend::Xnnpack2d;
|
|
// The 3x3 depthwise convolution implementation is inference-only
|
|
} else if (!need_backward && params.use_cpu_depthwise3x3_winograd(input, weight, bias)) {
|
|
return ConvBackend::Winograd3x3Depthwise;
|
|
} else if (
|
|
!params.transposed && (input.ndimension() == 5) &&
|
|
(input.device().is_cpu()) &&
|
|
!params.is_dilated()) {
|
|
// fast path for grouped conv3d
|
|
return ConvBackend::Slow3d;
|
|
} else if (input.device().is_cpu() || input.is_cuda()) {
|
|
// backends without support for groups
|
|
if (params.transposed) {
|
|
if (input.ndimension() == 4) {
|
|
return ConvBackend::SlowTranspose2d;
|
|
} else if (input.ndimension() == 5) {
|
|
return ConvBackend::SlowTranspose3d;
|
|
} else {
|
|
// unsupported
|
|
}
|
|
} else { /* Not transposed */
|
|
if (input.ndimension() == 4) {
|
|
if (params.is_dilated()) {
|
|
return ConvBackend::SlowDilated2d;
|
|
} else { /* dim == 4, non-dilated */
|
|
if (params.use_nnpack(input, weight)) {
|
|
return ConvBackend::NnpackSpatial;
|
|
} else {
|
|
/* CPU implementation has specialized MM kernels
|
|
for non-dilated case here */
|
|
return ConvBackend::Slow2d;
|
|
}
|
|
}
|
|
} else if (input.ndimension() == 5 && (input.is_cuda() || params.is_dilated())) {
|
|
return ConvBackend::SlowDilated3d;
|
|
} else if (input.ndimension() == 5) { /* dim == 5, CPU, non-dilated */
|
|
/* CPU implementation has specialized MM kernels
|
|
for non-dilated case here */
|
|
return ConvBackend::Slow3d;
|
|
} else {
|
|
// unsupported
|
|
}
|
|
}
|
|
} else if (params.use_mps(input, weight)) {
|
|
if (params.transposed) {
|
|
return ConvBackend::MpsTranspose;
|
|
} else {
|
|
return ConvBackend::Mps;
|
|
}
|
|
} else {
|
|
// Only reached when the input is on a backend with an out-of-source implementation.
|
|
return ConvBackend::Overrideable;
|
|
}
|
|
|
|
// Error out if no suitable backend was found.
|
|
TORCH_CHECK(false, "unsupported ConvNd parameters");
|
|
}
|
|
|
|
// Selects a backend for convolution based on the inputs and params.
|
|
ConvBackend select_conv_backend(
|
|
const Tensor& input_r, const Tensor& weight_r, const std::optional<Tensor>& bias_opt,
|
|
SymIntArrayRef stride_, SymIntArrayRef padding_, SymIntArrayRef dilation_,
|
|
bool transposed_, SymIntArrayRef output_padding_, c10::SymInt groups_, const at::OptionalSymIntArrayRef bias_sizes_opt) {
|
|
c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
|
|
const Tensor& bias = *bias_maybe_owned;
|
|
|
|
auto& ctx = at::globalContext();
|
|
auto k = weight_r.ndimension();
|
|
int64_t dim = k - 2;
|
|
ConvParams<c10::SymInt> params;
|
|
params.stride = expand_param_if_needed(stride_, "stride", dim);
|
|
params.padding = expand_param_if_needed(padding_, "padding", dim);
|
|
params.dilation = expand_param_if_needed(dilation_, "dilation", dim);
|
|
params.transposed = transposed_;
|
|
params.output_padding = expand_param_if_needed(output_padding_, "output_padding", dim);
|
|
params.groups = std::move(groups_);
|
|
params.benchmark = ctx.benchmarkCuDNN();
|
|
params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
|
|
params.cudnn_enabled = ctx.userEnabledCuDNN();
|
|
params.allow_tf32 = ctx.allowTF32CuDNN(at::Float32Op::CONV);
|
|
|
|
auto input = input_r;
|
|
auto weight = weight_r;
|
|
check_shape_forward(input, weight.sym_sizes(), bias, params);
|
|
|
|
// Expand 1d -> 2d.
|
|
// This is only done for backends that don't natively support 1d spatial input.
|
|
if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) {
|
|
// avoid accidentally going through NHWC for permuted 3d input.
|
|
input = input.contiguous();
|
|
params.view1d_as_2d();
|
|
input = view4d(input);
|
|
weight = view4d(weight);
|
|
}
|
|
|
|
auto bias_sizes = bias.defined() ? std::optional<SymIntArrayRef>(bias.sym_sizes()) : bias_sizes_opt;
|
|
bool need_backward = GradMode::is_enabled() &&
|
|
(input.requires_grad() || weight.requires_grad() || (bias.defined() && bias.requires_grad()));
|
|
return _select_conv_backend(input, weight, bias, bias_sizes, need_backward, params);
|
|
}
|
|
|
|
// For BC reasons, have a copy that does not require bias_opt
|
|
static ConvBackend select_conv_backend(
|
|
const Tensor& input,
|
|
const Tensor& weight,
|
|
const at::OptionalIntArrayRef bias_sizes_opt,
|
|
const bool need_backward,
|
|
const ConvParams<int64_t>& params) {
|
|
return _select_conv_backend(input, weight, {}, bias_sizes_opt, need_backward, params);
|
|
}
|
|
|
|
static at::Tensor _convolution_nogroup_backend(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    const ConvBackend backend,
    const ConvParams<int64_t>& params) {
  auto kernel_size = weight.sizes().slice(2);
  switch(backend) {
    case ConvBackend::NnpackSpatial:
#if AT_NNPACK_ENABLED()
      return at::_nnpack_spatial_convolution(input, weight, bias, params.padding, params.stride);
#else
      TORCH_INTERNAL_ASSERT(false, "NnpackSpatial backend was selected in PyTorch compiled without nnpack support");
#endif
    case ConvBackend::Slow2d:
      return at::thnn_conv2d(input, weight, kernel_size, bias, params.stride, params.padding);
    case ConvBackend::SlowDilated2d:
      return at::slow_conv_dilated2d(
          input, weight, kernel_size, bias, params.stride, params.padding, params.dilation);
    case ConvBackend::SlowDilated3d:
      return at::slow_conv_dilated3d(
          input, weight, kernel_size, bias, params.stride, params.padding, params.dilation);
    case ConvBackend::SlowTranspose2d:
      return at::slow_conv_transpose2d(
          input, weight, kernel_size, bias, params.stride, params.padding, params.output_padding, params.dilation);
    case ConvBackend::SlowTranspose3d:
      return at::slow_conv_transpose3d(
          input, weight, kernel_size, bias, params.stride, params.padding, params.output_padding, params.dilation);
    default:
      TORCH_CHECK(false, "Unsupported conv nogroup backend encountered");
  }
}
static inline std::vector<int64_t> calc_output_size(
    const Tensor& input,
    const Tensor& weight,
    const ConvParams<int64_t>& params) {
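  // For reference: in the non-transposed case conv_output_size computes, per spatial dim,
  //   out = (in + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
  // (floor division); conv_input_size inverts this (plus output_padding) for the transposed case.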
  std::vector<int64_t> output_size = params.transposed ?
    conv_input_size(input.sizes(), weight.sizes(), params.padding, params.output_padding,
                    params.stride, params.dilation, params.groups) :
    conv_output_size(input.sizes(), weight.sizes(), params.padding, params.stride, params.dilation);

  // Handle empty # of channels.
  if (input.size(input_channels_dim) == 0) {
    output_size[output_channels_dim] = 0;
  }
  return output_size;
}
static inline at::MemoryFormat determine_backend_memory_format(
    const Tensor& input,
    const Tensor& weight,
    const ConvBackend backend) {
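  // A channels-last format (ChannelsLast for 4-d, ChannelsLast3d for 5-d) is only suggested
  // when the backend-specific heuristic below chooses it for this input/weight pair;
  // every other case falls back to Contiguous (NCHW).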
  auto backend_memory_format = at::MemoryFormat::Contiguous;
#if !defined(C10_MOBILE)
  auto k = weight.ndimension();
  // See Note [Mobile check segfaults]
  switch(backend) {
    case ConvBackend::Cudnn:
    case ConvBackend::CudnnTranspose:
      if (detail::getCUDAHooks().compiledWithCuDNN()) {
        backend_memory_format = cudnn_conv_suggest_memory_format(input, weight);
      }
      break;
    case ConvBackend::Miopen:
    case ConvBackend::MiopenDepthwise:
    case ConvBackend::MiopenTranspose:
      if (detail::getCUDAHooks().compiledWithMIOpen()) {
        backend_memory_format = miopen_conv_suggest_memory_format(input, weight);
      }
      break;
    case ConvBackend::Mkldnn:
    case ConvBackend::MkldnnTranspose:
      if (mkldnn_conv_use_channels_last(input, weight)) {
        backend_memory_format = (k == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
      }
      break;
    case ConvBackend::Slow2d:
    case ConvBackend::SlowDilated2d:
    case ConvBackend::SlowTranspose2d:
      if (thnn_conv_use_channels_last(input, weight)) {
        backend_memory_format = at::MemoryFormat::ChannelsLast;
      }
      break;
    case ConvBackend::Overrideable:
      if (xpu_conv_use_channels_last(input, weight)) {
        backend_memory_format = (k == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
      }
      break;
    case ConvBackend::Mps:
    case ConvBackend::MpsTranspose:
      if (mps_conv_use_channels_last(input, weight)) {
        backend_memory_format = (k == 5) ? MemoryFormat::ChannelsLast3d : MemoryFormat::ChannelsLast;
      }
      break;
    default:
      backend_memory_format = at::MemoryFormat::Contiguous;
  }
#endif
  return backend_memory_format;
}
at::MemoryFormat _determine_backend_memory_format(
    const Tensor& input,
    const Tensor& weight,
    const ConvBackend backend) {
  return determine_backend_memory_format(input, weight, backend);
}
at::Tensor _convolution(
    const Tensor& input_r, const Tensor& weight_r, const std::optional<Tensor>& bias_r_opt,
    IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_,
    bool transposed_, IntArrayRef output_padding_, int64_t groups_,
    bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_r_maybe_owned = at::borrow_from_optional_tensor(bias_r_opt);
  const Tensor& bias_r = *bias_r_maybe_owned;

  auto input = input_r;
  auto weight = weight_r;
  auto bias = bias_r;
  auto k = weight.ndimension();
  c10::IntArrayRef weight_sizes = weight.sizes();
  int64_t dim = k - 2;

  TORCH_CHECK(dim > 0, "weight should have at least three dimensions");
  TORCH_CHECK(groups_ > 0, "non-positive groups is not supported");

  ConvParams<int64_t> params;
  params.stride = expand_param_if_needed(stride_, "stride", dim);
  params.padding = expand_param_if_needed(padding_, "padding", dim);
  params.dilation = expand_param_if_needed(dilation_, "dilation", dim);
  params.transposed = transposed_;
  params.output_padding = expand_param_if_needed(output_padding_, "output_padding", dim);
  params.groups = groups_;
  params.benchmark = benchmark;
  params.deterministic = deterministic;
  params.cudnn_enabled = cudnn_enabled;
  params.allow_tf32 = allow_tf32;

  check_shape_forward(input, weight_sizes, bias, params);

  // Expand 1d -> 2d.
  // This is only done for backends that don't natively support 1d spatial input.
  if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) {
    // avoid accidentally going through NHWC for permuted 3d input.
    input = input.contiguous();
    params.view1d_as_2d();
    input = view4d(input);
    weight = view4d(weight);
  }

  // Select appropriate backend to use.
  auto bias_sizes_opt = bias.defined() ? std::optional<IntArrayRef>(bias.sizes()) : std::nullopt;
  bool need_backward = GradMode::is_enabled() &&
      (input.requires_grad() || weight.requires_grad() || (bias.defined() && bias.requires_grad()));
  ConvBackend backend = _select_conv_backend(input, weight, bias, c10::OptionalIntArrayRef(bias_sizes_opt), need_backward, params);
  at::MemoryFormat backend_memory_format = determine_backend_memory_format(input, weight, backend);
  // Call the backend.
  Tensor output;
  auto kernel_size = weight.sizes().slice(2);
  switch (backend) {
    case ConvBackend::CudaDepthwise2d:
      output = at::_conv_depthwise2d(input.contiguous(), weight, kernel_size, bias,
          params.stride, params.padding, params.dilation);
      break;
    case ConvBackend::CudaDepthwise3d:
      output = at::conv_depthwise3d(input.contiguous(), weight, kernel_size, bias,
          params.stride, params.padding, params.dilation);
      break;
    case ConvBackend::Cudnn:
      check_input_same_type_as_parameters(input, weight, bias);
      output = at::cudnn_convolution(
          input.contiguous(backend_memory_format), weight, params.padding, params.stride,
          params.dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32);
      if (bias.defined()) {
        output.add_(reshape_bias(input.dim(), bias));
      }
      break;
    case ConvBackend::CudnnTranspose:
      check_input_same_type_as_parameters(input, weight, bias);
      output = at::cudnn_convolution_transpose(
          input.contiguous(backend_memory_format), weight, params.padding, params.output_padding,
          params.stride, params.dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32);
      if (bias.defined()) {
        output.add_(reshape_bias(input.dim(), bias));
      }
      break;
    case ConvBackend::Empty:
    {
      Tensor weight_view;
      // Use permute and clone to avoid at::_unsafe_view(weight, -1) failure for non-contiguous cases where
      // view size is not compatible with input tensor's size and stride.
      if (weight.is_contiguous()) {
        weight_view = at::_unsafe_view(weight, -1);
      } else if (weight.is_contiguous(at::MemoryFormat::ChannelsLast)) {
        weight_view = at::_unsafe_view(at::permute(weight, {0, 2, 3, 1}), -1);
      } else if (weight.is_contiguous(at::MemoryFormat::ChannelsLast3d)) {
        weight_view = at::_unsafe_view(at::permute(weight, {0, 2, 3, 4, 1}), -1);
      } else {
        weight_view = at::_unsafe_view(weight.clone(at::MemoryFormat::Contiguous), -1);
      }

      output = (input.size(1) == 0) ? (input.view(-1) * weight_view) : (input * weight_view[0]);
      if (bias.defined()) {
        output.add_(bias[0]);
      }
      output = output.view(calc_output_size(input, weight, params));
      break;
    }
    case ConvBackend::Miopen:
      check_input_same_type_as_parameters(input, weight, bias);
      output = at::miopen_convolution(
          input.contiguous(backend_memory_format), weight, bias, params.padding, params.stride,
          params.dilation, params.groups, params.benchmark, params.deterministic);
      break;
    case ConvBackend::MiopenDepthwise:
      output = at::miopen_depthwise_convolution(
          input.contiguous(backend_memory_format), weight, bias, params.padding, params.stride,
          params.dilation, params.groups, params.benchmark, params.deterministic);
      break;
    case ConvBackend::MiopenTranspose:
      check_input_same_type_as_parameters(input, weight, bias);
      output = at::miopen_convolution_transpose(
          input.contiguous(backend_memory_format), weight, bias, params.padding, params.output_padding,
          params.stride, params.dilation, params.groups, params.benchmark, params.deterministic);
      break;
    case ConvBackend::Mkldnn:
#if AT_MKLDNN_ENABLED()
      check_input_same_type_as_parameters(input, weight, bias, backend);
      if (!input.is_mkldnn()) {
        // need to ensure contiguous for non-mkldnn tensors
        input = input.contiguous(backend_memory_format);
        weight = weight.contiguous(backend_memory_format);
        bias = bias.defined() ? bias.contiguous() : bias;
      }
      output = at::mkldnn_convolution(
          input, weight, bias, params.padding, params.stride, params.dilation, params.groups);
#else
      TORCH_INTERNAL_ASSERT(false, "Mkldnn backend was selected in PyTorch compiled without mkldnn support");
#endif
      break;
    case ConvBackend::MkldnnTranspose:
#if AT_MKLDNN_ENABLED()
      check_input_same_type_as_parameters(input, weight, bias, backend);
      if (!input.is_mkldnn()) {
        // need to ensure contiguous for non-mkldnn tensors
        input = input.contiguous(backend_memory_format);
        weight = weight.contiguous(backend_memory_format);
        bias = bias.defined() ? bias.contiguous() : bias;
      }
      output = mkldnn_convolution_transpose_stub(input.device().type(),
          input, weight, bias, params.padding, params.output_padding, params.stride, params.dilation, params.groups);
#else
      TORCH_INTERNAL_ASSERT(false, "Mkldnn backend was selected in PyTorch compiled without mkldnn support");
#endif
      break;
    case ConvBackend::MkldnnEmpty:
#if AT_MKLDNN_ENABLED()
      output = empty_mkldnn(
          calc_output_size(input, weight, params), optTypeMetaToScalarType(input.options().dtype_opt()),
          input.options().layout_opt(), input.options().device_opt(), input.options().pinned_memory_opt());
#else
      TORCH_INTERNAL_ASSERT(false, "Mkldnn backend was selected in PyTorch compiled without mkldnn support");
#endif
      break;
    case ConvBackend::Overrideable:
      output = at::convolution_overrideable(
          input, weight, bias, params.stride, params.padding, params.dilation, params.transposed,
          params.output_padding, params.groups);
      break;
    case ConvBackend::Slow3d:
      output = at::slow_conv3d(input, weight, kernel_size, bias, params.stride, params.padding);
      break;
    case ConvBackend::Winograd3x3Depthwise:
      output = convolution_depthwise3x3_winograd_stub(
          input.device().type(), input, weight, bias, params.stride, params.padding, params.groups);
      break;
    case ConvBackend::Xnnpack2d:
      output = xnnpack::convolution2d(
          input, weight, bias, params.padding, params.stride, params.dilation, params.groups);
      break;
    // Handle backends that don't natively support groups > 1.
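    // Illustrative example: with groups = 2, input channels are split in half along dim 1 and
    // weight along dim 0; each half is run through the nogroup backend independently and the
    // per-group outputs are concatenated back along the channel dimension.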
    case ConvBackend::NnpackSpatial:
    case ConvBackend::Slow2d:
    case ConvBackend::SlowDilated2d:
    case ConvBackend::SlowDilated3d:
    case ConvBackend::SlowTranspose2d:
    case ConvBackend::SlowTranspose3d:
      input = input.contiguous(backend_memory_format);
      weight = weight.contiguous(backend_memory_format);
      if (params.groups == 1) {
        output = _convolution_nogroup_backend(input, weight, bias, backend, params);
      } else {
        std::vector<Tensor> outputs(params.groups);
        for (const auto g : c10::irange(params.groups)) {
          auto input_g = subtensor(input, 1, params.groups, g);
          auto weight_g = subtensor(weight, 0, params.groups, g);
          auto bias_g = subtensor(bias, 0, params.groups, g);
          outputs[g] = _convolution_nogroup_backend(input_g, weight_g, bias_g, backend, params);
        }
        output = at::cat(outputs, 1);
      }
      break;
    case ConvBackend::Mps:
#ifdef USE_MPS
      check_input_same_type_as_parameters(input, weight, bias);
      output = at::_mps_convolution(input, weight, bias.defined() ? bias.contiguous() : bias,
                                    params.padding, params.stride, params.dilation,
                                    params.groups);
#else
      TORCH_INTERNAL_ASSERT(false, "MPS backend was selected in PyTorch compiled without MPS support");
#endif
      break;
    case ConvBackend::MpsTranspose:
#ifdef USE_MPS
      check_input_same_type_as_parameters(input, weight, bias);
      output = at::_mps_convolution_transpose(
          input.contiguous(backend_memory_format), weight,
          params.padding, params.output_padding,
          params.stride, params.dilation, params.groups);
      if (bias.defined()) {
        output.add_(reshape_bias(input.dim(), bias));
      }
#else
      TORCH_INTERNAL_ASSERT(false, "MPS backend was selected in PyTorch compiled without MPS support");
#endif
      break;
  }

  if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) {
    output = view3d(output);
  }

  return output;
}
at::Tensor _convolution(
    const Tensor& input_r, const Tensor& weight_r, const std::optional<Tensor>& bias_r_opt,
    IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_,
    bool transposed_, IntArrayRef output_padding_, int64_t groups_,
    bool benchmark, bool deterministic, bool cudnn_enabled)
{
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_r_maybe_owned = at::borrow_from_optional_tensor(bias_r_opt);
  const Tensor& bias_r = *bias_r_maybe_owned;

  return at::_convolution(input_r, weight_r, bias_r, stride_, padding_, dilation_, transposed_, output_padding_, groups_, benchmark, deterministic, cudnn_enabled, at::globalContext().allowTF32CuDNN(at::Float32Op::CONV));
}
std::tuple<Tensor, Tensor, Tensor> convolution_backward_overrideable(
    const Tensor& grad_output, const Tensor& input, const Tensor& weight,
    IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation,
    bool transposed, IntArrayRef output_padding, int64_t groups, std::array<bool, 3> output_mask) {
  TORCH_CHECK_NOT_IMPLEMENTED(false, "convolution_backward_overrideable: You are likely triggering this with a tensor backend other than CPU/CUDA/MKLDNN; if this is intended, please use TORCH_LIBRARY_IMPL to override this function");
  return std::tuple<Tensor, Tensor, Tensor>(
      at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT),
      at::empty_like(weight, LEGACY_CONTIGUOUS_MEMORY_FORMAT),
      at::empty({}));
}
static Tensor subvariable(const Tensor& var, int64_t dim, int64_t groups, int64_t g) {
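  // e.g. with var.size(dim) == 8, groups == 4, g == 2 this selects elements [4, 6) along dim.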
  int64_t n = var.sizes()[dim] / groups;
  auto result = var.narrow(dim, n * g, n);
  return result;
}
std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward(const std::optional<Tensor>& ggI_opt, const std::optional<Tensor>& ggW_r_opt, const std::optional<Tensor>& ggb_opt,
    const Tensor& gO_r, const Tensor& weight_r, const Tensor& input,
    IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_,
    bool transposed_, IntArrayRef output_padding_, int64_t groups_,
    std::array<bool, 3> output_mask) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> ggI_maybe_owned = at::borrow_from_optional_tensor(ggI_opt);
  const Tensor& ggI = *ggI_maybe_owned;
  Tensor ggW = ggW_r_opt.value_or(Tensor());
  const Tensor& ggb = ggb_opt.value_or(Tensor());

  auto gO = gO_r;
  auto weight = weight_r;

  int64_t dim = weight.ndimension() - 2;
  ConvParams<int64_t> params;
  params.stride = expand_param_if_needed(stride_, "stride", dim);
  params.padding = expand_param_if_needed(padding_, "padding", dim);
  params.dilation = expand_param_if_needed(dilation_, "dilation", dim);
  params.transposed = transposed_;
  params.output_padding = expand_param_if_needed(output_padding_, "output_padding", dim);
  // TODO: hacky way of inferring the groups number for grouped Conv3D
  // See: https://github.com/pytorch/pytorch/pull/36355
  if (!params.transposed && input.dim() > 4) {
    // Avoid undefined behavior when num channels == 0; params are unused for that case.
    params.groups = (weight.size(1) > 0) ? input.size(1) / weight.size(1) : -1;
  } else {
    params.groups = groups_;
  }

  // Compute ggO = conv(ggI, w) + conv(i, ggW) + ggb
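  // Since convolution is linear in the input and in the weight separately, the forward-mode
  // derivative of conv(i, w) + b is conv(ggI, w) + conv(i, ggW) + ggb, and each term is only
  // computed when the corresponding tangent is defined.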
  Tensor ggO;
  if (input.numel() != 0) {
    if (ggI.defined()) {
      if (weight.is_cuda()) {
        weight = weight.contiguous();
      }
      ggO = at::convolution(ggI, weight, Tensor(), params.stride, params.padding, params.dilation, params.transposed, params.output_padding, params.groups);
    }

    if (ggW.defined()) {
      if (ggW.is_cuda()) {
        ggW = ggW.contiguous();
      }
      auto ggW_term = at::convolution(input, ggW, Tensor(), params.stride, params.padding, params.dilation, params.transposed, params.output_padding, params.groups);
      if (ggO.defined()) {
        ggO = ggO + ggW_term;
      } else {
        ggO = ggW_term;
      }
    }
  }

  if (ggb.defined()) {
    // View ggb as (1, ggb.size(0), 1, 1...)
    std::vector<int64_t> new_size(gO.ndimension(), 1);
    new_size[1] = ggb.sizes()[0];
    auto ggb_contiguous = ggb.contiguous();
    auto ggb_view = ggb_contiguous.view(new_size);

    // Expand
    auto ggb_expanded = ggb_view.expand(gO.sizes());

    if (ggO.defined()) {
      ggO = ggO + ggb_expanded;
    } else {
      ggO = ggb_expanded;
    }
  }

  // Compute gW = conv(ggI, gO)
  Tensor gW;
  if (ggI.defined()) {

    // Modified params with correct padding
    ConvParams<int64_t> gw_conv_params(params);

    // Disable groups as they are handled separately
    auto groups = gw_conv_params.groups;
    gw_conv_params.groups = 1;
    std::swap(gw_conv_params.dilation, gw_conv_params.stride);
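    // After transposing gO and ggI below, the batch dimension plays the role of channels, so
    // convolving ggIt with gOt correlates them over the batch; swapping stride and dilation
    // above makes that correlation produce a tensor with the (possibly slightly larger) shape
    // of the weight gradient, which is cropped back to weight's shape further down.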
    // Transpose gO and ggI to accumulate over batch
    auto gOt = gO.transpose(0, 1);
    auto ggIt = ggI.transpose(0, 1);

    Tensor gWt;
    // Compute conv
    if (input.numel() != 0) {
      if (groups == 1) {

        if (gOt.is_cuda()) {
          gOt = gOt.contiguous();
        }
        // Compute conv
        if (params.transposed) {
          gw_conv_params.transposed = false;
          gWt = at::convolution(gOt, ggIt, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups);
        } else {
          gWt = at::convolution(ggIt, gOt, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups);
        }
      } else {
        std::vector<Tensor> gWt_list(groups);
        for (const auto g : c10::irange(groups)) {
          auto ggIt_g = subvariable(ggIt, 0, groups, g);
          auto gOt_g = subvariable(gOt, 0, groups, g);
          if (gOt_g.is_cuda()) {
            gOt_g = gOt_g.contiguous();
          }

          // Compute conv
          if (params.transposed) {
            gw_conv_params.transposed = false;
            gWt_list[g] = at::convolution(gOt_g, ggIt_g, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups);
          } else {
            gWt_list[g] = at::convolution(ggIt_g, gOt_g, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups);
          }
        }

        gWt = at::cat(gWt_list, 1);
      }

      // Transpose gW to match chan_in and chan_out
      gW = gWt.transpose(0, 1);

      // narrow gW to only relevant portion
      // we do it this way instead of narrowing the input itself because
      // the ConvForward kernels don't support asymmetric padding.
      auto gW_size = gW.sizes();
      auto w_size = weight.sizes();
      for (const auto i : c10::irange(2, static_cast<int64_t>(gW_size.size()))) {
        if (gW_size[i] > w_size[i]) {
          gW = gW.narrow(i, 0, w_size[i]);
          gW_size = gW.sizes();
        }
      }
    }
  }
  // Compute gI = convT(gO, ggW) if !transposed
  //         gI = conv(gO, ggW)  if transposed
  Tensor gI;
  if (input.numel() != 0) {
    if (ggW.defined()) {
      ConvParams<int64_t> gi_conv_params(params);
      gi_conv_params.transposed = !params.transposed;

      if (params.transposed) {
        if (gO.is_cuda()) {
          gO = gO.contiguous();
        }
        gI = at::convolution(gO, ggW, Tensor(), gi_conv_params.stride, gi_conv_params.padding, gi_conv_params.dilation, gi_conv_params.transposed, gi_conv_params.output_padding, gi_conv_params.groups);

        // narrow gI to only relevant portion
        // we do it this way because negative output_padding is not supported
        // TODO: figure out if we can narrow gO and save some compute,
        // rather than narrowing the computed gI
        auto gI_size = gI.sizes();
        auto i_size = input.sizes();
        for (const auto i : c10::irange(2, static_cast<int64_t>(gI_size.size()))) {
          if (gI_size[i] > i_size[i]) {
            gI = gI.narrow(i, 0, i_size[i]);
            gI_size = gI.sizes();
          }
        }
      } else {
        // calculate output_padding
        // TODO: figure out why this needs to be computed...
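        // The transposed convolution below maps gO back to an input-sized tensor, but several
        // input sizes can map to the same forward output size. expected_input_shape is what the
        // transposed convolution would produce with output_padding = 0; any shortfall relative
        // to the true input shape is made up via output_padding so gI matches input's shape.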
        auto kernel_size = weight.sizes().slice(2);
        auto input_shape = input.sizes().slice(2);
        auto grad_output_shape = gO.sizes().slice(2);

        for (const auto i : c10::irange(kernel_size.size())) {
          // Check if whole input has been used or not
          auto expected_input_shape = (kernel_size[i] - 1) * gi_conv_params.dilation[i]
            - 2 * gi_conv_params.padding[i]
            + (gi_conv_params.stride[i] * (grad_output_shape[i] - 1) + 1);
          if (expected_input_shape != input_shape[i]) {
            gi_conv_params.output_padding[i] = input_shape[i] - expected_input_shape;
          }
        }

        if (gO.is_cuda()) {
          gO = gO.contiguous();
        }

        gI = at::convolution(gO, ggW, Tensor(), gi_conv_params.stride, gi_conv_params.padding, gi_conv_params.dilation, gi_conv_params.transposed, gi_conv_params.output_padding, gi_conv_params.groups);
      }
    }
  }

  return std::tuple<Tensor,Tensor,Tensor>{ggO, gI, gW};
}
static std::tuple<at::Tensor, at::Tensor, at::Tensor> _convolution_backward_nogroup_backend(
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    const std::array<bool, 3> output_mask,
    const ConvBackend backend,
    const ConvParams<int64_t>& params) {
  auto kernel_size = weight.sizes().slice(2);
  switch(backend) {
    case ConvBackend::Slow2d:
      return at::_slow_conv2d_backward(
          grad_output, input, weight, kernel_size, params.stride, params.padding, output_mask);
    // NB: nnpack backward does not support strided convolutions; use slow impl instead
    case ConvBackend::NnpackSpatial:
    case ConvBackend::SlowDilated2d:
      return slow_conv_dilated2d_backward_stub(
          input.device().type(),
          grad_output, input, weight, kernel_size, params.stride, params.padding, params.dilation, output_mask);
    case ConvBackend::SlowDilated3d:
      return slow_conv_dilated3d_backward_stub(
          input.device().type(),
          grad_output, input, weight, kernel_size, params.stride, params.padding, params.dilation, output_mask);
    case ConvBackend::SlowTranspose2d:
      return slow_conv_transpose2d_backward_stub(
          input.device().type(), grad_output, input, weight, kernel_size, params.stride, params.padding,
          params.output_padding, params.dilation, output_mask);
    case ConvBackend::SlowTranspose3d:
      return slow_conv_transpose3d_backward_stub(
          input.device().type(), grad_output, input, weight, kernel_size, params.stride, params.padding,
          params.output_padding, params.dilation, output_mask);
    default:
      TORCH_CHECK(false, "Unsupported conv nogroup backend encountered");
  }
}
// Backward pass for convolution. Computes gradients for input, weight, and bias depending on the
// output_mask setting. This function supports 1D, 2D, or 3D spatial convolution and currently requires
// a single batch dimension to be present.
//
// Args:
//   grad_output_: tensor of shape (N, C_out, L_out), (N, C_out, H_out, W_out), or (N, C_out, D_out, H_out, W_out)
//   input_: tensor of shape (N, C_in, L_in), (N, C_in, H_in, W_in), or (N, C_in, D_in, H_in, W_in)
//   weight_: tensor of shape (C_out, C_in // groups, *kernel_size); dimension of kernel_size must match the number
//     of input spatial dimensions
//   bias_sizes_opt: if specified, indicates that a bias was used in the forward pass and contains the shape
//     of the bias. While the bias shape can be computed from other inputs, it is provided to this function for
//     ease of use. The bias shape is (weight.shape[0]) for normal convolution and (weight.shape[1] * groups)
//     for transposed convolution.
//   stride: single value or an array with dimension matching the number of input spatial dimensions
//   padding: single value or an array with dimension matching the number of input spatial dimensions
//   dilation: single value or an array with dimension matching the number of input spatial dimensions
//   transposed: boolean indicating whether the convolution is transposed
//   output_padding: single value or dimension == number of input spatial dimensions; only supported when
//     transposed is true
//   groups: number of groups for grouped convolution
//   output_mask: 3-dim boolean array specifying which gradients to compute in input, weight, bias order
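//
// Illustrative example of the bias-shape rule above: for a 2d convolution with a weight of
// shape (128, 8, 3, 3) and groups = 4, bias_sizes_opt would be {128} (= weight.shape[0]);
// for a transposed convolution with a weight of that shape, it would be {8 * 4} = {32}
// (= weight.shape[1] * groups).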
std::tuple<Tensor, Tensor, Tensor> convolution_backward(
    const Tensor& grad_output_, const Tensor& input_, const Tensor& weight_,
    const at::OptionalIntArrayRef bias_sizes_opt,
    IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool transposed, IntArrayRef output_padding,
    int64_t groups, std::array<bool, 3> output_mask) {
  auto grad_output = grad_output_;
  auto input = input_;
  auto weight = weight_;

  auto k = weight.ndimension();
  int64_t dim = k - 2;

  TORCH_CHECK(dim > 0, "weight should have at least three dimensions");

  auto& ctx = at::globalContext();
  ConvParams<int64_t> params;
  params.stride = expand_param_if_needed(stride, "stride", dim);
  params.padding = expand_param_if_needed(padding, "padding", dim);
  params.dilation = expand_param_if_needed(dilation, "dilation", dim);
  params.transposed = transposed;
  params.output_padding = expand_param_if_needed(output_padding, "output_padding", dim);
  params.groups = groups;
  params.benchmark = ctx.benchmarkCuDNN();
  params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
  params.cudnn_enabled = ctx.userEnabledCuDNN();
  params.allow_tf32 = ctx.allowTF32CuDNN(at::Float32Op::CONV);

  // Validate inputs.
  check_shape_backward(input, weight.sizes(), params);
  TORCH_CHECK(input.dim() == grad_output.dim(),
      "Expected input and grad_output to have the same number of dimensions, but got: ",
      input.dim(), " and ", grad_output.dim());

  // output_padding is only supported for transposed convolutions
  if (!params.transposed) {
    for (auto pad : params.output_padding) {
      TORCH_CHECK(pad == 0, "output_padding is not supported for non-transposed convolutions; got: ",
          params.output_padding);
    }
  }

  // Expand 1d -> 2d.
  // This is only done for backends that don't natively support 1d spatial input.
  if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) {
    // avoid accidentally going through NHWC for permuted 3d input.
    input = input.contiguous();
    params.view1d_as_2d();
    grad_output = view4d(grad_output);
    input = view4d(input);
    weight = view4d(weight);
  }

  // Select appropriate backend to use.
  ConvBackend backend = select_conv_backend(input, weight, bias_sizes_opt, /*need_backward=*/ true, params);
  at::MemoryFormat backend_memory_format = determine_backend_memory_format(input, weight, backend);

  // Call the backend.
  Tensor backend_grad_input, backend_grad_weight, backend_grad_bias;
  auto kernel_size = weight.sizes().slice(2);
  switch(backend) {
    case ConvBackend::CudaDepthwise2d:
    {
      std::array<bool, 2> input_weight_output_mask = {output_mask[0], output_mask[1]};
      std::tie(backend_grad_input, backend_grad_weight) =
        conv_depthwise2d_backward_stub(input.device().type(), grad_output, input,
          weight, kernel_size, params.stride, params.padding, params.dilation, input_weight_output_mask);
      break;
    }
    case ConvBackend::CudaDepthwise3d:
      TORCH_CHECK(input.ndimension() == 5);
      std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
        conv_depthwise3d_backward_stub(
          input.device().type(), grad_output, input, weight, kernel_size, params.stride,
          params.padding, params.dilation, output_mask);
      break;
    case ConvBackend::Cudnn:
    {
      check_input_same_type_as_parameters(input, weight);
      std::array<bool, 2> input_weight_output_mask = {output_mask[0], output_mask[1]};
      std::tie(backend_grad_input, backend_grad_weight) = cudnn_convolution_backward_stub(
        input.device().type(),
        // Only make input contiguous when it is necessary for the backwards computation
        output_mask[1] ? input.contiguous(backend_memory_format) : input,
        grad_output, weight, params.padding, params.stride,
        params.dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32,
        input_weight_output_mask);
      break;
    }
    case ConvBackend::Mps:
    {
#ifdef USE_MPS
      check_input_same_type_as_parameters(input, weight);
      std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
        at::mps_convolution_backward(input, grad_output, weight, params.padding,
          params.stride, params.dilation, params.groups, output_mask);
#else
      TORCH_INTERNAL_ASSERT(false, "MPS backend was selected in PyTorch compiled without MPS support");
#endif
      break;
    }
    case ConvBackend::MpsTranspose:
    {
#ifdef USE_MPS
      check_input_same_type_as_parameters(input, weight);
      std::array<bool, 2> input_weight_output_mask = {output_mask[0], output_mask[1]};
      std::tie(backend_grad_input, backend_grad_weight) = at::mps_convolution_transpose_backward(
        // Only make input contiguous when it is necessary for the backwards computation
        output_mask[1] ? input.contiguous(backend_memory_format) : input,
        grad_output, weight, params.padding, params.output_padding,
        params.stride, params.dilation, params.groups, input_weight_output_mask);
#else
      TORCH_INTERNAL_ASSERT(false, "MPS backend was selected in PyTorch compiled without MPS support");
#endif
      break;
    }
    case ConvBackend::CudnnTranspose:
    {
      check_input_same_type_as_parameters(input, weight);
      std::array<bool, 2> input_weight_output_mask = {output_mask[0], output_mask[1]};
      std::tie(backend_grad_input, backend_grad_weight) = cudnn_convolution_transpose_backward_stub(
        input.device().type(),
        // Only make input contiguous when it is necessary for the backwards computation
        output_mask[1] ? input.contiguous(backend_memory_format) : input,
        grad_output, weight, params.padding, params.output_padding,
        params.stride, params.dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32,
        input_weight_output_mask);
      break;
    }
    case ConvBackend::Empty:
      if (output_mask[0]) {
        backend_grad_input = at::zeros_like(input);
      }
      if (output_mask[1]) {
        backend_grad_weight = at::zeros_like(weight);
      }
      if (output_mask[2]) {
        backend_grad_bias = at::zeros(*bias_sizes_opt, weight.options());
      }
      break;
    case ConvBackend::MkldnnEmpty:
#if AT_MKLDNN_ENABLED()
      if (output_mask[0]) {
        if (input.is_mkldnn()) {
          backend_grad_input = empty_mkldnn(input.sizes(), optTypeMetaToScalarType(input.options().dtype_opt()),
              input.options().layout_opt(), input.options().device_opt(), input.options().pinned_memory_opt());
          backend_grad_input.zero_();
        } else {
          backend_grad_input = at::zeros_like(input);
        }
      }
      if (output_mask[1]) {
        // mkldnn weight is not supported during training by the mkldnn backend
        backend_grad_weight = at::zeros_like(weight);
      }
      if (output_mask[2]) {
        // mkldnn bias is not supported during training by the mkldnn backend
        backend_grad_bias = at::zeros(*bias_sizes_opt, weight.options());
      }
#else
      TORCH_INTERNAL_ASSERT(false, "Mkldnn backend was selected in PyTorch compiled without mkldnn support");
#endif
      break;
    case ConvBackend::Miopen:
      check_input_same_type_as_parameters(input, weight);
      std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
        miopen_convolution_backward_stub(
          input.device().type(),
          input.contiguous(backend_memory_format), grad_output, weight, params.padding, params.stride,
          params.dilation, params.groups, params.benchmark, params.deterministic, output_mask);
      break;
    case ConvBackend::MiopenDepthwise:
      std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
        miopen_depthwise_convolution_backward_stub(
          input.device().type(),
          input.contiguous(backend_memory_format), grad_output, weight, params.padding, params.stride,
          params.dilation, params.groups, params.benchmark, params.deterministic, output_mask);
      break;
    case ConvBackend::MiopenTranspose:
      check_input_same_type_as_parameters(input, weight);
      std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
        miopen_convolution_transpose_backward_stub(
          input.device().type(),
          input.contiguous(backend_memory_format), grad_output, weight, params.padding, params.output_padding,
          params.stride, params.dilation, params.groups, params.benchmark, params.deterministic, output_mask);
      break;
    case ConvBackend::Mkldnn:
      TORCH_CHECK(!weight.is_mkldnn(),
          "The MKLDNN backend does not support weight as an MKLDNN tensor during training");
      if (!input.is_mkldnn()) {
        input = input.contiguous(backend_memory_format);
        weight = weight.contiguous(backend_memory_format);
      }
      std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
        mkldnn_convolution_backward_stub(input.device().type(), input, grad_output, weight, params.padding,
          params.stride, params.dilation, params.groups, output_mask);
      break;
    case ConvBackend::MkldnnTranspose:
      TORCH_CHECK(!weight.is_mkldnn(),
          "The MKLDNN backend does not support weight as an MKLDNN tensor during training");
      if (!input.is_mkldnn()) {
        input = input.contiguous(backend_memory_format);
        weight = weight.contiguous(backend_memory_format);
      }
      std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
        mkldnn_convolution_transpose_backward_stub(input.device().type(), input, grad_output, weight, params.padding,
          params.output_padding, params.stride, params.dilation, params.groups, output_mask);
      break;
    case ConvBackend::Overrideable:
      // Only reach here when input is backend with out-of-source implementation.
      std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
        at::convolution_backward_overrideable(grad_output, input, weight, params.stride, params.padding,
          params.dilation, params.transposed, params.output_padding, params.groups, output_mask);
      break;
    case ConvBackend::Slow3d:
      // Note that no CUDA implementation of this kernel exists currently.
      std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
        slow_conv3d_backward_cpu(
            grad_output, input, weight, kernel_size,
            params.stride, params.padding, output_mask);
      break;
    // Handle backends that don't natively support groups > 1.
    case ConvBackend::NnpackSpatial:
    case ConvBackend::Slow2d:
    case ConvBackend::SlowDilated2d:
    case ConvBackend::SlowDilated3d:
    case ConvBackend::SlowTranspose2d:
    case ConvBackend::SlowTranspose3d:
    {
      input = input.contiguous(backend_memory_format);
      weight = weight.contiguous(backend_memory_format);
      if (params.groups == 1) {
        std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
          _convolution_backward_nogroup_backend(
            grad_output, input, weight, output_mask, backend, params);
      } else {
        std::vector<Tensor> backend_grad_inputs(params.groups);
        std::vector<Tensor> backend_grad_weights(params.groups);
        std::vector<Tensor> backend_grad_biases(params.groups);
        for (int g = 0; g < params.groups; ++g) {
          auto grad_output_g = subtensor(grad_output, 1, params.groups, g);
          auto input_g = subtensor(input, 1, params.groups, g);
          auto weight_g = subtensor(weight, 0, params.groups, g);
          std::tie(backend_grad_inputs[g], backend_grad_weights[g], backend_grad_biases[g]) =
            _convolution_backward_nogroup_backend(
              grad_output_g, input_g, weight_g, output_mask, backend, params);
        }
        if (output_mask[0]) {
          backend_grad_input = at::cat(backend_grad_inputs, 1);
        }
        if (output_mask[1]) {
          backend_grad_weight = at::cat(backend_grad_weights, 0);
        }
        if (output_mask[2]) {
          backend_grad_bias = at::cat(backend_grad_biases, 0);
        }
      }
      break;
    }
    // Backward is not supported for these backends.
    case ConvBackend::Winograd3x3Depthwise:
      TORCH_CHECK(false, "Backward is not supported for depthwise 3x3 winograd");
      break;
    case ConvBackend::Xnnpack2d:
      TORCH_CHECK(false, "Backward is not supported for xnnpack");
      break;
  }
  // Convert 2D inputs back to 1D for backends that don't natively support 1D
  // spatial inputs.
  if (output_mask[0]) {
    if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) {
      backend_grad_input = view3d(backend_grad_input);
    }
  }
  if (output_mask[1]) {
    if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) {
      backend_grad_weight = view3d(backend_grad_weight);
    }
  }
  if (output_mask[2]) {
    if (!backend_grad_bias.defined()) {
      // Calculate bias gradients outside of the backend for those that don't support it.
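      // i.e. grad_bias[c] = sum of grad_output over the batch and all spatial dims for channel c.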
      backend_grad_bias = grad_output.sum((dim == 3) ? IntArrayRef{0, 2, 3, 4} : IntArrayRef{0, 2, 3});
    }
  }

  return std::make_tuple(backend_grad_input, backend_grad_weight, backend_grad_bias);
}
void _cudnn_set_conv_benchmark_empty_cache(bool enable) {
  conv_benchmark_empty_cache = enable;
}

bool _cudnn_get_conv_benchmark_empty_cache() {
  return conv_benchmark_empty_cache;
}

} // namespace at::native