Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
symbolic cpp channels_last_contiguous (#160402)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160402 Approved by: https://github.com/aorenste
Committed by PyTorch MergeBot
Parent: 70d36e047d
Commit: 79fcd5247a
@@ -33,7 +33,8 @@ bool _compute_contiguous(ArrayRef<T> sizes, ArrayRef<T> strides, T numel) {
 }
 
 // Return a SymBool with underlying symbolic expression that represents
-// contiguity. Guaranteed not to add guards.
+// contiguity. Guaranteed not to throw a DDE; may return a symbolic
+// expression or a symbolic True.
 inline static c10::SymBool _compute_contiguous_sym(
     ArrayRef<c10::SymInt> sizes,
     ArrayRef<c10::SymInt> strides,
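
The updated comment above tightens the contract of _compute_contiguous_sym: it never throws a data-dependent error (DDE) and returns either a constant or a symbolic SymBool. A minimal caller sketch of how such a result can be consumed, mirroring the SymbolicShapeMeta methods later in this diff; the consumer functions are hypothetical, and the numel argument is assumed to follow the existing signature:

// Sketch only; assumes sizes, strides (ArrayRef<c10::SymInt>) and numel
// (c10::SymInt) are in scope.
c10::SymBool is_contig = _compute_contiguous_sym(sizes, strides, numel);
if (auto known = is_contig.maybe_as_bool()) {
  // The expression simplified to a plain bool without guarding.
  handle_constant(*known);    // hypothetical consumer
} else {
  // Still symbolic: no data-dependent guard was taken to get here.
  handle_symbolic(is_contig); // hypothetical consumer
}
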
@@ -76,6 +77,8 @@ inline static c10::SymBool _compute_contiguous_sym(
     return true;
   };
 
+  // We try to minimize creating large symbolic expressions when not needed to
+  // avoid symbolic evaluation perf issues.
   if (is_contiguous_or_false()) {
     return c10::SymBool(true);
   }
@@ -94,6 +97,9 @@ inline static c10::SymBool _compute_contiguous_sym(
   return is_contiguous_cond.sym_or(is_empty);
 }
 
+// When T is SymInt this function may throw a data dependent error.
+// _compute_channels_last_contiguous_2d_sym does not. Only use this function
+// when inputs are hinted.
 template <typename T>
 bool _compute_channels_last_contiguous_2d(
     ArrayRef<T> sizes,
@@ -105,8 +111,8 @@ bool _compute_channels_last_contiguous_2d(
       T expected = 1;
       for (auto& d : {1, 3, 2, 0}) {
         const auto& size_d = sizes[d];
-        if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(size_d, 1))) {
-          if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(strides[d], expected))) {
+        if (size_d != 1) {
+          if (strides[d] != expected) {
             return false;
           }
           expected *= size_d;
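
In the eager check above, the dimension order {1, 3, 2, 0} walks C, W, H, N, i.e. the innermost-to-outermost axes of NHWC (channels-last) memory. A small integer-only sketch, not part of the change, showing the same accumulation of expected strides:

#include <array>
#include <cstdint>

// Same check as above, specialized to plain integers for illustration.
bool is_channels_last_2d(
    const std::array<int64_t, 4>& sizes,   // {N, C, H, W}
    const std::array<int64_t, 4>& strides) {
  int64_t expected = 1;
  for (int d : {1, 3, 2, 0}) {
    if (sizes[d] != 1) {
      if (strides[d] != expected) {
        return false;
      }
      expected *= sizes[d];
    }
  }
  return true;
}

// For sizes {2, 3, 4, 5}, channels-last strides are {C*H*W, 1, C*W, C} =
// {60, 1, 15, 3}; expected visits 1 -> 3 -> 15 -> 60 in loop order.
// is_channels_last_2d({2, 3, 4, 5}, {60, 1, 15, 3}) == true
// is_channels_last_2d({2, 3, 4, 5}, {60, 20, 5, 1}) == false (NCHW strides)
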
@@ -123,6 +129,65 @@ bool _compute_channels_last_contiguous_2d(
   }
 }
 
+// Return a SymBool with underlying symbolic expression that represents
+// contiguity. Guaranteed not to throw a DDE; may return a symbolic
+// expression or a symbolic True.
+inline static c10::SymBool _compute_channels_last_contiguous_2d_sym(
+    ArrayRef<c10::SymInt> sizes,
+    ArrayRef<c10::SymInt> strides) {
+  switch (sizes.size()) {
+    case 4: {
+      // When this function returns True, the result is always true. When it
+      // returns False, the result could be False or data dependent.
+      auto guard_or_false = [&]() {
+        c10::SymInt expected = 1;
+        for (auto& d : {1, 3, 2, 0}) {
+          const auto& size_d = sizes[d];
+          // Not taking this branch could make this return False instead of
+          // True but not vice-versa, so it's ok.
+          if (TORCH_GUARD_OR_FALSE(sym_eq(sizes[d], 1))) {
+            continue;
+          }
+          // Taking this branch could make this return False instead of True
+          // but not vice-versa, so it's ok.
+          if (TORCH_GUARD_OR_TRUE(sym_ne(strides[d], expected))) {
+            return false;
+          }
+          expected *= size_d;
+        }
+        return true;
+      };
+
+      // We try to minimize creating large symbolic expressions when not needed
+      // to avoid symbolic evaluation perf issues.
+      if (guard_or_false()) {
+        return c10::SymBool(true);
+      }
+
+      // Result is either false, or data dependent.
+      c10::SymInt expected_stride = 1;
+      c10::SymBool cond = true;
+
+      for (auto& d : {1, 3, 2, 0}) {
+        const auto& size_d = sizes[d];
+        cond = cond.sym_and(
+            size_d.sym_eq(1).sym_or(sym_eq(strides[d], expected_stride)));
+        expected_stride *= size_d;
+      }
+      return cond;
+    }
+    // NOLINTNEXTLINE(bugprone-branch-clone)
+    case 3:
+      // TODO dim == 3 case will be enabled once it is fully tested
+      return c10::SymBool(false);
+    default:
+      return c10::SymBool(false);
+  }
+}
+
+// When T is SymInt this function may throw a data dependent error.
+// _compute_channels_last_contiguous_3d_sym does not. Only use this function
+// when inputs are hinted.
 template <typename T>
 bool _compute_channels_last_contiguous_3d(
     ArrayRef<T> sizes,
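
When the guard_or_false fast path above cannot prove channels-last contiguity, the slow path conjoins one clause per dimension. Expanded by hand for a fully dynamic NCHW tensor with sizes {s0, s1, s2, s3} and strides {t0, t1, t2, t3} (names chosen for illustration), the returned SymBool represents:

  (s1 == 1 || t1 == 1)
  && (s3 == 1 || t3 == s1)
  && (s2 == 1 || t2 == s1 * s3)
  && (s0 == 1 || t0 == s1 * s3 * s2)

The fast path keeps this expression from being materialized whenever TORCH_GUARD_OR_FALSE / TORCH_GUARD_OR_TRUE can settle every dimension.
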
@@ -134,8 +199,8 @@ bool _compute_channels_last_contiguous_3d(
       T expected = 1;
       for (auto& d : {1, 4, 3, 2, 0}) {
        const auto& size_d = sizes[d];
-        if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(size_d, 1))) {
-          if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(strides[d], expected))) {
+        if (size_d != 1) {
+          if (strides[d] != expected) {
            return false;
          }
          expected *= size_d;
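
The 3-D variant above runs the same check over five dimensions; {1, 4, 3, 2, 0} walks C, W, H, D, N, i.e. NDHWC (channels-last-3d) memory order. A hand-worked instance with illustrative values only:

  sizes {N, C, D, H, W} = {2, 3, 4, 5, 6}
  channels-last-3d strides = {C*D*H*W, 1, C*H*W, C*W, C} = {360, 1, 90, 18, 3}
  expected visits 1 -> 3 -> 18 -> 90 -> 360, matching strides[1], [4], [3], [2], [0].
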
@@ -152,6 +217,59 @@ bool _compute_channels_last_contiguous_3d(
   }
 }
 
+inline static c10::SymBool _compute_channels_last_contiguous_3d_sym(
+    ArrayRef<c10::SymInt> sizes,
+    ArrayRef<c10::SymInt> strides) {
+  switch (sizes.size()) {
+    case 5: {
+      // When this function returns True, the result is always true. When it
+      // returns False, the result could be False or data dependent.
+      auto guard_or_false = [&]() {
+        c10::SymInt expected = 1;
+        for (auto& d : {1, 4, 3, 2, 0}) {
+          const auto& size_d = sizes[d];
+          // Not taking this branch could make this return False instead of
+          // True but not vice-versa, so it's ok.
+          if (TORCH_GUARD_OR_FALSE(sym_eq(sizes[d], 1))) {
+            continue;
+          }
+          // Taking this branch could make this return False instead of True
+          // but not vice-versa, so it's ok.
+          if (TORCH_GUARD_OR_TRUE(sym_ne(strides[d], expected))) {
+            return false;
+          }
+          expected *= size_d;
+        }
+        return true;
+      };
+
+      // We try to minimize creating large symbolic expressions when not needed
+      // to avoid symbolic evaluation perf issues.
+      if (guard_or_false()) {
+        return c10::SymBool(true);
+      }
+
+      // Result is either false, or data dependent.
+      c10::SymInt expected_stride = 1;
+      c10::SymBool cond = true;
+
+      for (auto& d : {1, 4, 3, 2, 0}) {
+        const auto& size_d = sizes[d];
+        cond = cond.sym_and(
+            size_d.sym_eq(1).sym_or(sym_eq(strides[d], expected_stride)));
+        expected_stride *= size_d;
+      }
+      return cond;
+    }
+    // NOLINTNEXTLINE(bugprone-branch-clone)
+    case 4:
+      // TODO dim == 4 case will be enabled once it is fully tested
+      return c10::SymBool(false);
+    default:
+      return c10::SymBool(false);
+  }
+}
+
 template <typename T>
 bool _compute_non_overlapping_and_dense(
     ArrayRef<T> sizes,

@@ -71,6 +71,27 @@ normalize_sym_sizes_strides(SymIntArrayRef sizes, SymIntArrayRef strides) {
   return std::tuple<SymNode, std::vector<SymNode>, std::vector<SymNode>>(
       std::move(base), std::move(size_nodes), std::move(stride_nodes));
 }
+namespace {
+bool all_hinted(
+    const c10::SymIntArrayRef& sizes,
+    const c10::SymIntArrayRef& strides) {
+  auto all_hinted = true;
+  for (const auto& s : sizes) {
+    if (!s.has_hint()) {
+      return false;
+    }
+  }
+
+  if (all_hinted) {
+    for (const auto& s : strides) {
+      if (!s.has_hint()) {
+        return false;
+      }
+    }
+  }
+  return all_hinted;
+}
+} // namespace
 
 // Special treatment because of numel
 SymBool SymbolicShapeMeta::compute_contiguous() const {
@@ -88,24 +109,7 @@ SymBool SymbolicShapeMeta::compute_contiguous() const {
     return maybe_as_bool.value();
   }
 
-  auto all_hinted = true;
-  for (const auto& s : sizes) {
-    if (!s.has_hint()) {
-      all_hinted = false;
-      break;
-    }
-  }
-
-  if (all_hinted) {
-    for (const auto& s : strides) {
-      if (!s.has_hint()) {
-        all_hinted = false;
-        break;
-      }
-    }
-  }
-
-  if (all_hinted) {
+  if (all_hinted(sizes, strides)) {
     // We avoid going through the slow path if everything is hinted,
     // because evaluating a large SymPy expression can be expensive.
     // TODO exclude backed_size_oblivious from this path.
@@ -115,6 +119,56 @@ SymBool SymbolicShapeMeta::compute_contiguous() const {
   return result;
 }
 
+SymBool SymbolicShapeMeta::compute_channels_last_contiguous_2d() const {
+  if (!strides_valid_) {
+    return false;
+  }
+  c10::SymIntArrayRef sizes(sizes_);
+  c10::SymIntArrayRef strides(strides_);
+
+  auto result = _compute_channels_last_contiguous_2d_sym(sizes, strides);
+
+  // If the result is already determined without guarding, just return it.
+  auto maybe_as_bool = result.maybe_as_bool();
+  if (maybe_as_bool.has_value()) {
+    return maybe_as_bool.value();
+  }
+
+  if (all_hinted(sizes, strides)) {
+    // We avoid going through the slow path if everything is hinted,
+    // because evaluating a large SymPy expression can be expensive.
+    // TODO exclude backed_size_oblivious from this path.
+    return _compute_channels_last_contiguous_2d<SymInt>(sizes_, strides_);
+  }
+
+  return result;
+}
+
+SymBool SymbolicShapeMeta::compute_channels_last_contiguous_3d() const {
+  if (!strides_valid_) {
+    return false;
+  }
+  c10::SymIntArrayRef sizes(sizes_);
+  c10::SymIntArrayRef strides(strides_);
+
+  auto result = _compute_channels_last_contiguous_3d_sym(sizes, strides);
+
+  // If the result is already determined without guarding, just return it.
+  auto maybe_as_bool = result.maybe_as_bool();
+  if (maybe_as_bool.has_value()) {
+    return maybe_as_bool.value();
+  }
+
+  if (all_hinted(sizes, strides)) {
+    // We avoid going through the slow path if everything is hinted,
+    // because evaluating a large SymPy expression can be expensive.
+    // TODO exclude backed_size_oblivious from this path.
+    return _compute_channels_last_contiguous_3d<SymInt>(sizes_, strides_);
+  }
+
+  return result;
+}
+
 // The rest of them
 #define DEFINE_EAGER_SYMBOOL_COMPUTE(name, fallback) \
   SymBool SymbolicShapeMeta::name() const { \
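
Both new methods follow the same three-step pattern: build the DDE-free symbolic result, return it immediately if it already collapsed to a constant, and otherwise prefer the eager hinted computation so a large SymPy expression is never evaluated needlessly. A condensed sketch of that shared shape; the helper name and function-pointer parameters are illustrative, not part of the change, and all_hinted refers to the anonymous-namespace helper added earlier in this diff:

// Illustrative only: distills the dispatch shared by the two methods above.
static c10::SymBool symbolic_then_hinted(
    c10::SymIntArrayRef sizes,
    c10::SymIntArrayRef strides,
    c10::SymBool (*sym_fn)(c10::SymIntArrayRef, c10::SymIntArrayRef),
    bool (*hinted_fn)(c10::ArrayRef<c10::SymInt>, c10::ArrayRef<c10::SymInt>)) {
  auto result = sym_fn(sizes, strides); // DDE-free, possibly symbolic
  if (auto known = result.maybe_as_bool()) {
    return *known; // already collapsed to a constant
  }
  if (all_hinted(sizes, strides)) {
    return hinted_fn(sizes, strides); // cheap eager evaluation on hints
  }
  return result; // genuinely symbolic result
}

For instance, compute_channels_last_contiguous_2d() above is this pattern applied to _compute_channels_last_contiguous_2d_sym and _compute_channels_last_contiguous_2d<SymInt>, plus the strides_valid_ early-out.
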
@@ -143,8 +197,6 @@ SymBool SymbolicShapeMeta::compute_contiguous() const {
 }
 
 // clang-format off
-DEFINE_EAGER_SYMBOOL_COMPUTE(compute_channels_last_contiguous_2d, _compute_channels_last_contiguous_2d)
-DEFINE_EAGER_SYMBOOL_COMPUTE(compute_channels_last_contiguous_3d, _compute_channels_last_contiguous_3d)
 DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_2d, is_channels_last_strides_2d)
 DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_3d, is_channels_last_strides_3d)