mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160402 Approved by: https://github.com/aorenste
407 lines
12 KiB
C++
407 lines
12 KiB
C++
#include <c10/core/Contiguity.h>
|
|
#include <c10/core/MemoryFormat.h>
|
|
#include <c10/core/SymInt.h>
|
|
#include <c10/core/SymIntArrayRef.h>
|
|
#include <c10/core/SymbolicShapeMeta.h>
|
|
|
|
namespace c10 {
|
|
|
|
// Copy constructor. The non-mutable members are copied in the member
// initializer list; the lazily computed members and the availability bitmask
// are copied in the body while other's mutex is held.
SymbolicShapeMeta::SymbolicShapeMeta(const SymbolicShapeMeta& other)
    // Non-mutables can be accessed outside the mutex
    : sizes_(other.sizes_),
      strides_(other.strides_),
      storage_offset_(other.storage_offset_),
      strides_valid_(other.strides_valid_) {
  std::scoped_lock guard(other.mutables_);
  // The assignments below have to happen under the lock, so they cannot be
  // turned into member initializers; clang-tidy is silenced for that reason.
  // NOLINTBEGIN(cppcoreguidelines-prefer-member-initializer)
  numel_ = other.numel_;

  // Cached contiguity results.
  is_contiguous_ = other.is_contiguous_;
  is_channels_last_contiguous_ = other.is_channels_last_contiguous_;
  is_channels_last_3d_contiguous_ = other.is_channels_last_3d_contiguous_;

  // Cached stride-layout results.
  is_channels_last_ = other.is_channels_last_;
  is_channels_last_3d_ = other.is_channels_last_3d_;
  is_non_overlapping_and_dense_ = other.is_non_overlapping_and_dense_;

  // Copy the availability bitmask last, after the values it describes.
  const auto available_bits = other.available_.load();
  available_.store(available_bits);
  // NOLINTEND(cppcoreguidelines-prefer-member-initializer)
}
|
|
|
|
// base, sizes, strides
|
|
static std::optional<
|
|
std::tuple<SymNode, std::vector<SymNode>, std::vector<SymNode>>>
|
|
normalize_sym_sizes_strides(SymIntArrayRef sizes, SymIntArrayRef strides) {
|
|
// Look for a SymNode to dispatch on
|
|
SymNode base;
|
|
bool all_hinted = true;
|
|
// NB: sizes/strides guaranteed to be positive, so only need
|
|
// is_heap_allocated
|
|
for (const auto& s : sizes) {
|
|
if (all_hinted && !s.has_hint()) {
|
|
all_hinted = false;
|
|
}
|
|
if (!base && s.is_heap_allocated()) {
|
|
base = s.toSymNode();
|
|
}
|
|
}
|
|
for (const auto& s : strides) {
|
|
if (all_hinted && !s.has_hint()) {
|
|
all_hinted = false;
|
|
}
|
|
if (!base && s.is_heap_allocated()) {
|
|
base = s.toSymNode();
|
|
}
|
|
}
|
|
if (!base || all_hinted) {
|
|
// Couldn't find. Tell the caller to do the normal computation
|
|
// Alternately, if everything is hinted, we want the normal computation
|
|
// too
|
|
return std::nullopt;
|
|
}
|
|
// Populate the SymNode array
|
|
std::vector<SymNode> size_nodes;
|
|
std::vector<SymNode> stride_nodes;
|
|
size_nodes.reserve(sizes.size());
|
|
stride_nodes.reserve(strides.size());
|
|
for (const auto& s : sizes) {
|
|
size_nodes.emplace_back(s.wrap_node(base));
|
|
}
|
|
for (const auto& s : strides) {
|
|
stride_nodes.emplace_back(s.wrap_node(base));
|
|
}
|
|
return std::tuple<SymNode, std::vector<SymNode>, std::vector<SymNode>>(
|
|
std::move(base), std::move(size_nodes), std::move(stride_nodes));
|
|
}
|
|
namespace {
|
|
bool all_hinted(
|
|
const c10::SymIntArrayRef& sizes,
|
|
const c10::SymIntArrayRef& strides) {
|
|
auto all_hinted = true;
|
|
for (const auto& s : sizes) {
|
|
if (!s.has_hint()) {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
if (all_hinted) {
|
|
for (const auto& s : strides) {
|
|
if (!s.has_hint()) {
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
return all_hinted;
|
|
}
|
|
} // namespace
|
|
|
|
// Special treatment because of numel
|
|
// Computes the contiguity predicate from sizes_, strides_ and numel().
// Returns false outright when the strides are not valid.
SymBool SymbolicShapeMeta::compute_contiguous() const {
  if (!strides_valid_) {
    return false;
  }
  c10::SymIntArrayRef sizes(sizes_);
  c10::SymIntArrayRef strides(strides_);

  auto symbolic = _compute_contiguous_sym(sizes, strides, numel());

  // A result that is already determined without guarding is returned as-is.
  if (const auto known = symbolic.maybe_as_bool(); known.has_value()) {
    return *known;
  }

  if (!all_hinted(sizes, strides)) {
    return symbolic;
  }

  // We avoid going through the slow path if everything is hinted,
  // because evaluating a large SymPy expression can be expensive.
  // TODO exclude backed_size_oblivious from this path.
  return _compute_contiguous<SymInt>(sizes_, strides_, numel());
}
|
|
|
|
// Computes the 2d channels-last contiguity predicate from sizes_ and
// strides_. Returns false outright when the strides are not valid.
SymBool SymbolicShapeMeta::compute_channels_last_contiguous_2d() const {
  if (!strides_valid_) {
    return false;
  }
  c10::SymIntArrayRef sizes(sizes_);
  c10::SymIntArrayRef strides(strides_);

  auto symbolic = _compute_channels_last_contiguous_2d_sym(sizes, strides);

  // A result that is already determined without guarding is returned as-is.
  if (const auto known = symbolic.maybe_as_bool(); known.has_value()) {
    return *known;
  }

  if (!all_hinted(sizes, strides)) {
    return symbolic;
  }

  // We avoid going through the slow path if everything is hinted,
  // because evaluating a large SymPy expression can be expensive.
  // TODO exclude backed_size_oblivious from this path.
  return _compute_channels_last_contiguous_2d<SymInt>(sizes_, strides_);
}
|
|
|
|
// Computes the 3d channels-last contiguity predicate from sizes_ and
// strides_. Returns false outright when the strides are not valid.
SymBool SymbolicShapeMeta::compute_channels_last_contiguous_3d() const {
  if (!strides_valid_) {
    return false;
  }
  c10::SymIntArrayRef sizes(sizes_);
  c10::SymIntArrayRef strides(strides_);

  auto symbolic = _compute_channels_last_contiguous_3d_sym(sizes, strides);

  // A result that is already determined without guarding is returned as-is.
  if (const auto known = symbolic.maybe_as_bool(); known.has_value()) {
    return *known;
  }

  if (!all_hinted(sizes, strides)) {
    return symbolic;
  }

  // We avoid going through the slow path if everything is hinted,
  // because evaluating a large SymPy expression can be expensive.
  // TODO exclude backed_size_oblivious from this path.
  return _compute_channels_last_contiguous_3d<SymInt>(sizes_, strides_);
}
|
|
|
|
// The rest of them
|
|
// Defines SymbolicShapeMeta::name() as a direct call to
// fallback(sizes, strides) on array-ref views of sizes_ and strides_.
// The generated method returns false when the strides are not valid.
#define DEFINE_EAGER_SYMBOOL_COMPUTE(name, fallback) \
  SymBool SymbolicShapeMeta::name() const {          \
    if (strides_valid_) {                            \
      c10::SymIntArrayRef size_view(sizes_);         \
      c10::SymIntArrayRef stride_view(strides_);     \
      return fallback(size_view, stride_view);       \
    }                                                \
    return false;                                    \
  }
|
|
|
|
// Defines SymbolicShapeMeta::name(). The generated method returns false when
// the strides are not valid. Otherwise it asks normalize_sym_sizes_strides
// for a base SymNode plus node-wrapped sizes/strides: if that succeeds it
// dispatches to base->nodeimpl(size_nodes, stride_nodes), and if it yields
// nothing it calls fallback(sizes, strides) on array-ref views of
// sizes_/strides_.
// The structured binding is taken by const reference so the tuple held in
// the optional (two vectors of SymNode plus the base node) is not copied.
#define DEFINE_SYMBOOL_COMPUTE(name, nodeimpl, fallback)          \
  SymBool SymbolicShapeMeta::name() const {                       \
    if (!strides_valid_) {                                        \
      return false;                                               \
    }                                                             \
    auto n = normalize_sym_sizes_strides(sizes_, strides_);       \
    if (n.has_value()) {                                          \
      const auto& [base, size_nodes, stride_nodes] = *n;          \
      return SymBool(base->nodeimpl(size_nodes, stride_nodes));   \
    } else {                                                      \
      c10::SymIntArrayRef sizes(sizes_);                          \
      c10::SymIntArrayRef strides(strides_);                      \
      return fallback(sizes, strides);                            \
    }                                                             \
  }
|
|
|
|
// clang-format off
|
|
DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_2d, is_channels_last_strides_2d)
|
|
DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_3d, is_channels_last_strides_3d)
|
|
|
|
DEFINE_SYMBOOL_COMPUTE(compute_non_overlapping_and_dense, is_non_overlapping_and_dense, _compute_non_overlapping_and_dense)
|
|
|
|
// clang-format on
|
|
|
|
#undef DEFINE_SYMBOOL_COMPUTE
|
|
|
|
// Glue compute
|
|
// NB: this logic very intentionally short circuits if possible. Without
|
|
// short circuiting, it causes
|
|
// python test/functorch/test_aotdispatch.py -k
|
|
// test_aot_autograd_symbolic_exhaustive_nn_functional_unfold_cpu_float32 to run
|
|
// very slowly.
|
|
|
|
// 4-dim case: returns true as soon as contiguity or channels-last contiguity
// can be established via guard_or_false; otherwise returns the disjunction of
// both predicates with compute_non_overlapping_and_dense().
SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_dim4() const {
  init_is_contiguous();
  if (guard_or_false(is_contiguous(), __FILE__, __LINE__)) {
    return true;
  }

  init_is_channels_last_contiguous();
  if (guard_or_false(is_channels_last_contiguous(), __FILE__, __LINE__)) {
    return true;
  }

  // No short circuit applied: build the full expression.
  auto any_contiguous = is_contiguous() | is_channels_last_contiguous();
  return any_contiguous | compute_non_overlapping_and_dense();
}
|
|
|
|
// 5-dim case: returns false when channels-last contiguity can be established
// via guard_or_false; otherwise returns "not channels-last contiguous AND
// compute_channels_last_contiguous_3d()".
SymBool SymbolicShapeMeta::compute_channels_last_contiguous_3d_dim5() const {
  init_is_channels_last_contiguous();
  if (guard_or_false(is_channels_last_contiguous(), __FILE__, __LINE__)) {
    return false;
  }

  auto not_channels_last_contiguous = ~is_channels_last_contiguous();
  return not_channels_last_contiguous & compute_channels_last_contiguous_3d();
}
|
|
|
|
// 5-dim case: returns false when 3d channels-last contiguity can be
// established via guard_or_false; otherwise returns "not 3d channels-last
// contiguous AND compute_strides_like_channels_last_2d()".
SymBool SymbolicShapeMeta::compute_channels_last_2d_dim5() const {
  init_is_channels_last_3d_contiguous();
  if (guard_or_false(is_channels_last_3d_contiguous(), __FILE__, __LINE__)) {
    return false;
  }

  auto not_3d_contiguous = ~is_channels_last_3d_contiguous();
  return not_3d_contiguous & compute_strides_like_channels_last_2d();
}
|
|
|
|
// 5-dim case: returns false when is_channels_last() can be established via
// guard_or_false; otherwise returns "not channels-last AND
// compute_strides_like_channels_last_3d()".
SymBool SymbolicShapeMeta::compute_channels_last_3d_dim5() const {
  if (guard_or_false(is_channels_last(), __FILE__, __LINE__)) {
    return false;
  }

  auto not_channels_last = ~is_channels_last();
  return not_channels_last & compute_strides_like_channels_last_3d();
}
|
|
|
|
// 5-dim case: returns true as soon as one of the three contiguity predicates
// can be established via guard_or_false (checked in order, stopping at the
// first hit); otherwise returns the disjunction of all three with
// compute_non_overlapping_and_dense().
SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_dim5() const {
  const bool some_layout_is_known_contiguous =
      guard_or_false(is_contiguous(), __FILE__, __LINE__) ||
      guard_or_false(is_channels_last_contiguous(), __FILE__, __LINE__) ||
      guard_or_false(is_channels_last_3d_contiguous(), __FILE__, __LINE__);
  if (some_layout_is_known_contiguous) {
    return true;
  }

  // No short circuit applied: build the full expression.
  auto any_contiguous = is_contiguous() | is_channels_last_contiguous() |
      is_channels_last_3d_contiguous();
  return any_contiguous | compute_non_overlapping_and_dense();
}
|
|
|
|
// Generic case: returns true when contiguity can be established via
// guard_or_false; otherwise returns "is_contiguous() OR
// compute_non_overlapping_and_dense()".
SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_anydim() const {
  const bool known_contiguous =
      guard_or_false(is_contiguous(), __FILE__, __LINE__);
  if (known_contiguous) {
    return true;
  }
  return is_contiguous() | compute_non_overlapping_and_dense();
}
|
|
|
|
// Stores `val` as the cached numel and marks it available, under mutables_.
// Does nothing if has_numel() already reports a value.
void SymbolicShapeMeta::set_numel(SymInt val) const {
  std::scoped_lock guard(mutables_);
  if (!has_numel()) {
    numel_ = std::move(val);
    available_.fetch_or(numel_avail);
  }
}
|
|
|
|
// Stores `val` as the cached is_contiguous result and marks it available,
// under mutables_. Does nothing if has_is_contiguous() already reports a
// value.
void SymbolicShapeMeta::set_is_contiguous(SymBool val) const {
  std::scoped_lock guard(mutables_);
  if (!has_is_contiguous()) {
    is_contiguous_ = std::move(val);
    available_.fetch_or(is_contiguous_avail);
  }
}
|
|
|
|
// Stores `val` as the cached is_channels_last_contiguous result and marks it
// available, under mutables_. Does nothing if
// has_is_channels_last_contiguous() already reports a value.
void SymbolicShapeMeta::set_is_channels_last_contiguous(SymBool val) const {
  std::scoped_lock guard(mutables_);
  if (!has_is_channels_last_contiguous()) {
    is_channels_last_contiguous_ = std::move(val);
    available_.fetch_or(is_channels_last_contiguous_avail);
  }
}
|
|
|
|
// Stores `val` as the cached is_channels_last_3d_contiguous result and marks
// it available, under mutables_. Does nothing if
// has_is_channels_last_3d_contiguous() already reports a value.
void SymbolicShapeMeta::set_is_channels_last_3d_contiguous(SymBool val) const {
  std::scoped_lock guard(mutables_);
  if (!has_is_channels_last_3d_contiguous()) {
    is_channels_last_3d_contiguous_ = std::move(val);
    available_.fetch_or(is_channels_last_3d_contiguous_avail);
  }
}
|
|
|
|
// Stores `val` as the cached is_channels_last result and marks it available,
// under mutables_. Does nothing if has_is_channels_last() already reports a
// value.
void SymbolicShapeMeta::set_is_channels_last(SymBool val) const {
  std::scoped_lock guard(mutables_);
  if (!has_is_channels_last()) {
    is_channels_last_ = std::move(val);
    available_.fetch_or(is_channels_last_avail);
  }
}
|
|
|
|
// Stores `val` as the cached is_channels_last_3d result and marks it
// available, under mutables_. Does nothing if has_is_channels_last_3d()
// already reports a value.
void SymbolicShapeMeta::set_is_channels_last_3d(SymBool val) const {
  std::scoped_lock guard(mutables_);
  if (!has_is_channels_last_3d()) {
    is_channels_last_3d_ = std::move(val);
    available_.fetch_or(is_channels_last_3d_avail);
  }
}
|
|
|
|
// Stores `val` as the cached is_non_overlapping_and_dense result and marks it
// available, under mutables_. Does nothing if
// has_is_non_overlapping_and_dense() already reports a value.
void SymbolicShapeMeta::set_is_non_overlapping_and_dense(SymBool val) const {
  std::scoped_lock guard(mutables_);
  if (!has_is_non_overlapping_and_dense()) {
    is_non_overlapping_and_dense_ = std::move(val);
    available_.fetch_or(is_non_overlapping_and_dense_avail);
  }
}
|
|
|
|
// Computes multiply_integers(sizes_) and hands the result to set_numel.
void SymbolicShapeMeta::init_numel() const {
  auto total = multiply_integers(sizes_);
  set_numel(std::move(total));
}
|
|
|
|
// Computes compute_contiguous() and hands the result to set_is_contiguous.
void SymbolicShapeMeta::init_is_contiguous() const {
  auto contiguous = compute_contiguous();
  set_is_contiguous(std::move(contiguous));
}
|
|
|
|
void SymbolicShapeMeta::init_is_channels_last_contiguous() const {
|
|
set_is_channels_last_contiguous([&] {
|
|
switch (dim()) {
|
|
case 5:
|
|
case 4: {
|
|
return compute_channels_last_contiguous_2d();
|
|
}
|
|
default:
|
|
return SymBool{false};
|
|
}
|
|
}());
|
|
}
|
|
|
|
void SymbolicShapeMeta::init_is_channels_last_3d_contiguous() const {
|
|
set_is_channels_last_3d_contiguous([&] {
|
|
switch (dim()) {
|
|
case 5:
|
|
return compute_channels_last_contiguous_3d_dim5();
|
|
default:
|
|
return SymBool{false};
|
|
}
|
|
}());
|
|
}
|
|
|
|
void SymbolicShapeMeta::init_is_channels_last() const {
|
|
set_is_channels_last([&] {
|
|
switch (dim()) {
|
|
case 5:
|
|
return compute_channels_last_2d_dim5();
|
|
case 4:
|
|
return compute_strides_like_channels_last_2d();
|
|
default:
|
|
return SymBool{false};
|
|
}
|
|
}());
|
|
}
|
|
|
|
void SymbolicShapeMeta::init_is_channels_last_3d() const {
|
|
set_is_channels_last_3d([&] {
|
|
switch (dim()) {
|
|
case 5:
|
|
return compute_channels_last_3d_dim5();
|
|
default:
|
|
return SymBool{false};
|
|
}
|
|
}());
|
|
}
|
|
|
|
void SymbolicShapeMeta::init_is_non_overlapping_and_dense() const {
|
|
set_is_non_overlapping_and_dense([&] {
|
|
switch (dim()) {
|
|
case 5:
|
|
return compute_is_non_overlapping_and_dense_dim5();
|
|
case 4:
|
|
return compute_is_non_overlapping_and_dense_dim4();
|
|
default:
|
|
return compute_is_non_overlapping_and_dense_anydim();
|
|
}
|
|
}());
|
|
}
|
|
|
|
} // namespace c10
|