#pragma once

#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>

#include <cstdint>
#include <ostream>
#include <vector>

// Memory format is not a property of a Tensor. It is a way to tell an
// operator how the result should be organized in memory and nothing more. That
// means memory format should never be used as a return value of any tensor
// state interrogation function (internally or externally).
//
// Possible options are:
//  Preserve:
//    If any of the input tensors is in channels_last format, the operator
//    output should be in channels_last format.
//
//  Contiguous:
//    Regardless of input tensors' format, the output should be a contiguous
//    Tensor.
//
//  ChannelsLast:
//    Regardless of input tensors' format, the output should be in
//    channels_last format.
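//
// As a concrete illustration: for a 4d tensor of shape [N, C, H, W], the
// contiguous (NCHW) layout has strides [C*H*W, H*W, W, 1], while the
// channels_last (NHWC) layout keeps the same [N, C, H, W] sizes but uses
// strides [H*W*C, 1, W*C, C], so the channel dimension becomes the
// fastest-moving one in memory.
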
namespace c10 {

enum class MemoryFormat : int8_t {
  Contiguous,
  Preserve,
  ChannelsLast,
  ChannelsLast3d,
  NumOptions
};

// If you are seeing this, it means that this call site was not checked to see
// whether the memory format could be preserved, and it was switched to the
// old default behaviour of contiguous.
#define LEGACY_CONTIGUOUS_MEMORY_FORMAT c10::get_contiguous_memory_format()

inline MemoryFormat get_contiguous_memory_format() {
  return MemoryFormat::Contiguous;
}
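// A hypothetical call site (illustrative only; any API taking a MemoryFormat
// argument works the same way):
//
//   // auto out = self.contiguous(LEGACY_CONTIGUOUS_MEMORY_FORMAT);
//   // expands to: self.contiguous(c10::get_contiguous_memory_format())
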
inline std::ostream& operator<<(
    std::ostream& stream,
    at::MemoryFormat memory_format) {
  switch (memory_format) {
    case MemoryFormat::Preserve:
      return stream << "Preserve";
    case MemoryFormat::Contiguous:
      return stream << "Contiguous";
    case MemoryFormat::ChannelsLast:
      return stream << "ChannelsLast";
    case MemoryFormat::ChannelsLast3d:
      return stream << "ChannelsLast3d";
    default:
      TORCH_CHECK(false, "Unknown memory format ", memory_format);
  }
}

// Note: The channels-last stride indices are hardcoded here for better
// performance.
template <typename T>
inline std::vector<T> get_channels_last_strides_2d(ArrayRef<T> sizes) {
  std::vector<T> strides(sizes.size());
  switch (sizes.size()) {
    case 4:
      strides[1] = 1;
      strides[3] = sizes[1];
      strides[2] = strides[3] * sizes[3];
      strides[0] = strides[2] * sizes[2];
      return strides;
    case 3:
      strides[0] = 1;
      strides[2] = sizes[0];
      strides[1] = strides[2] * sizes[2];
      return strides;
    default:
      TORCH_INTERNAL_ASSERT(
          false, "ChannelsLast2d doesn't support size ", sizes.size());
  }
}

inline std::vector<int64_t> get_channels_last_strides_2d(IntArrayRef sizes) {
  return get_channels_last_strides_2d<int64_t>(sizes);
}
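// Worked example: for sizes [N, C, H, W] = [2, 3, 4, 5],
// get_channels_last_strides_2d returns [60, 1, 15, 3]:
//   strides[1] = 1, strides[3] = C = 3,
//   strides[2] = C * W = 15, strides[0] = C * W * H = 60.
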
template <typename T>
std::vector<T> get_channels_last_strides_3d(ArrayRef<T> sizes) {
  std::vector<T> strides(sizes.size());
  switch (sizes.size()) {
    case 5:
      strides[1] = 1;
      strides[4] = sizes[1];
      strides[3] = strides[4] * sizes[4];
      strides[2] = strides[3] * sizes[3];
      strides[0] = strides[2] * sizes[2];
      return strides;
    case 4:
      strides[0] = 1;
      strides[3] = sizes[0];
      strides[2] = strides[3] * sizes[3];
      strides[1] = strides[2] * sizes[2];
      return strides;
    default:
      TORCH_INTERNAL_ASSERT(
          false, "ChannelsLast3d doesn't support size ", sizes.size());
  }
}

inline std::vector<int64_t> get_channels_last_strides_3d(IntArrayRef sizes) {
  return get_channels_last_strides_3d<int64_t>(sizes);
}
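// Worked example: for sizes [N, C, D, H, W] = [2, 3, 4, 5, 6],
// get_channels_last_strides_3d returns [360, 1, 90, 18, 3]:
//   strides[1] = 1, strides[4] = C = 3, strides[3] = C * W = 18,
//   strides[2] = C * W * H = 90, strides[0] = C * W * H * D = 360.
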
// NOTE:
// Below are helper functions for is_channels_last_strides_xd.
// 1. Please do not combine these helper functions; each one handles exactly
// one case of sizes + memory_format. This way the stride indices form a
// constant array that can be accessed with constant index numbers, so the
// compiler can fully unroll the loop over the stride indices for better
// performance.
// 2. There is no error checking in the helper functions; the caller ensures
// the correctness of the input.
// 3. All helper functions have similar structure, so only the first one is
// commented here.
template <typename T>
inline bool is_channels_last_strides_2d_s4(
    const ArrayRef<T> sizes,
    const ArrayRef<T> strides) {
  T min = 0;
  // special case for trivial C dimension. default to NCHW
  if (strides[1] == 0) {
    return false;
  }
  // loop strides indices
  for (auto& d : {1, 3, 2, 0}) {
    if (sizes[d] == 0) {
      return false;
    }
    if (strides[d] < min) {
      return false;
    }
    // Fallback to NCHW as default layout for ambiguous cases.
    // This is the flaw of implicit memory_format from strides:
    // an N111 tensor has identical strides for its size-1 dimensions.
    // Two cases could lead us here:
    // a. N111 contiguous Tensor ([N,1,1,1]@[1,1,1,1])
    // b. N11W contiguous Tensor sliced on the W-dimension
    //    ([N,1,1,1]@[W,W,W,W])
    if (d == 0 && min == strides[1]) {
      return false;
    }
    // This is necessary to:
    // 1. distinguish the memory_format of N1H1;
    //     [H, 1, 1, 1] channels_last stride
    //     [H, H, 1, 1] contiguous stride
    // 2. permutation of 1C1W:
    //     [1, C, 1, H]@[HC, H, H, 1] transpose(1, 3)
    //     [1, H, 1, C]@[HC, 1, H, H] shouldn't be identified as channels_last
    min = strides[d];
    if (sizes[d] > 1) {
      min *= sizes[d];
    }
  }
  return true;
}
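// For example, with sizes [N, C, H, W] = [2, 3, 4, 5]:
//   channels_last strides [60, 1, 15, 3]  -> true
//   contiguous strides    [60, 20, 5, 1]  -> false (strides[3] falls below
//                                            the running minimum)
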
template <typename T>
inline bool is_channels_last_strides_3d_s5(
    const ArrayRef<T> sizes,
    const ArrayRef<T> strides) {
  T min = 0;
  if (strides[1] == 0) {
    return false;
  }
  for (auto& d : {1, 4, 3, 2, 0}) {
    if (sizes[d] == 0) {
      return false;
    }
    if (strides[d] < min) {
      return false;
    }
    if (d == 0 && min == strides[1]) {
      return false;
    }
    min = strides[d];
    if (sizes[d] > 1) {
      min *= sizes[d];
    }
  }
  return true;
}

// Note [Ambiguous is_channels_last_strides_xd]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// The flaw of carrying memory_format implicitly through strides is very hard
// to work around properly; see issue #24090.
// Without the history of permutations, we can't infer the memory_format of a
// tensor from a snapshot of its size & stride, e.g.:
//
// 1. We can NOT specify the memory_format of an N111 tensor through strides
// in a meaningful way;
//
// 2. Two paths can end up with an identical size/stride:
//    An N11W contiguous tensor sliced at the W-dimension becomes
//    [N,1,1,1]@[W,W,W,W];
//    An NC11 channels_last tensor sliced at the C-dimension becomes
//    [N,1,1,1]@[C,C,C,C].
// So if we see a tensor [N,1,1,1]@[X,X,X,X], there's no way for us to infer
// the memory_format of the original tensor.
//
// Given these limitations, our temporary workaround `is_channels_last_strides`
// makes a best effort to infer whether the original memory_format of a tensor
// is at::MemoryFormat::ChannelsLast. The two objectives of this function
// (ordered by importance):
// 1. Ensure that normal shape manipulation does not accidentally change the
//    MemoryFormat of an existing tensor.
// 2. Allow users to mark MemoryFormat::ChannelsLast on tensors.
//
// The function does so by checking the strides of the tensor, including the
// strides of size-1 dimensions, although conventionally PyTorch places no
// restriction on trivial strides (strides for size-1 dimensions).
//
// Note that this approach is a compromise; we have not solved the problem
// completely, and in many cases we will not be able to infer the correct
// memory format.
// The implementation of `is_channels_last_strides` serves these objectives:
// MemoryFormat::ChannelsLast has to be explicitly opted into (no accidental
// conversion), with a best effort to maintain the ChannelsLast flag.
//
// Because this is not a bulletproof solution, through testing
// (aten/src/ATen/test/memory_format_test.cpp)
// a. we ensure that the common tasks are supported;
// b. we identify the corner cases where the implementation compromises.
//
// By the time accumulated permutations are enabled to replace implicit
// memory_format through strides, we should update our tests and fix the
// issues they reveal.
//
// Channels Last 2d is used as the example above, but this is a general
// problem for all the is_channels_last_strides_xd implementations. Please
// check the helper functions (is_channels_last_strides_*d_s*) for more
// details.
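// A concrete instance of the ambiguity above: for sizes [2, 1, 1, 1] with
// strides [5, 5, 5, 5] (an N11W tensor sliced at the W-dimension),
// is_channels_last_strides_2d returns false, falling back to NCHW as the
// default layout.
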
template <typename T>
inline bool is_channels_last_strides_2d(
    const ArrayRef<T> sizes,
    const ArrayRef<T> strides) {
  switch (sizes.size()) {
    case 4:
      return is_channels_last_strides_2d_s4(sizes, strides);
      // NOLINTNEXTLINE(bugprone-branch-clone)
    case 3:
      // TODO dim == 3 case will be enabled once it is fully tested
      return false;
    default:
      return false;
  }
}

template <typename T>
inline bool is_channels_last_strides_3d(
    const ArrayRef<T> sizes,
    const ArrayRef<T> strides) {
  switch (sizes.size()) {
    case 5:
      return is_channels_last_strides_3d_s5(sizes, strides);
      // NOLINTNEXTLINE(bugprone-branch-clone)
    case 4:
      // TODO dim == 4 case will be enabled once it is fully tested
      return false;
    default:
      return false;
  }
}

inline bool is_channels_last_strides_2d(
    const IntArrayRef sizes,
    const IntArrayRef strides) {
  return is_channels_last_strides_2d<int64_t>(sizes, strides);
}

inline bool is_channels_last_strides_3d(
    const IntArrayRef sizes,
    const IntArrayRef strides) {
  return is_channels_last_strides_3d<int64_t>(sizes, strides);
}

} // namespace c10