symbolic cpp channels_last_contiguous (#160402)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160402
Approved by: https://github.com/aorenste
This commit is contained in:
Laith Sakka
2025-09-05 07:26:05 -07:00
committed by PyTorch MergeBot
parent 70d36e047d
commit 79fcd5247a
2 changed files with 195 additions and 25 deletions

View File

@ -33,7 +33,8 @@ bool _compute_contiguous(ArrayRef<T> sizes, ArrayRef<T> strides, T numel) {
}
// Return a SymBool with underlying symbolic expression that represents
// contiguity. Guaranteed not to add guards.
// contiguity. Guaranteed not to throw DDE, may returns a symbolic expressions
// or symbolic True.
inline static c10::SymBool _compute_contiguous_sym(
ArrayRef<c10::SymInt> sizes,
ArrayRef<c10::SymInt> strides,
@ -76,6 +77,8 @@ inline static c10::SymBool _compute_contiguous_sym(
return true;
};
// We try to minimize creating large symbolic expressions when not needed to
// avoid symbolic evaluation perf issues.
if (is_contiguous_or_false()) {
return c10::SymBool(true);
}
@ -94,6 +97,9 @@ inline static c10::SymBool _compute_contiguous_sym(
return is_contiguous_cond.sym_or(is_empty);
}
// When T is SymInt this function may throw a data dependent error.
// _compute_channels_last_contiguous_2d_sym does not. Only use this function
// when inputs are hinted.
template <typename T>
bool _compute_channels_last_contiguous_2d(
ArrayRef<T> sizes,
@ -105,8 +111,8 @@ bool _compute_channels_last_contiguous_2d(
T expected = 1;
for (auto& d : {1, 3, 2, 0}) {
const auto& size_d = sizes[d];
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(size_d, 1))) {
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(strides[d], expected))) {
if (size_d != 1) {
if (strides[d] != expected) {
return false;
}
expected *= size_d;
@ -123,6 +129,65 @@ bool _compute_channels_last_contiguous_2d(
}
}
// Return a SymBool with underlying symbolic expression that represents
// contiguity. Guaranteed not to throw DDE, may returns a symbolic expressions
// or symbolic True.
inline static c10::SymBool _compute_channels_last_contiguous_2d_sym(
ArrayRef<c10::SymInt> sizes,
ArrayRef<c10::SymInt> strides) {
switch (sizes.size()) {
case 4: {
// When this function return True, result always true. When it return
// False, result could be False or data dependent.
auto guard_or_false = [&]() {
c10::SymInt expected = 1;
for (auto& d : {1, 3, 2, 0}) {
const auto& size_d = sizes[d];
// Not taking this branch could make this return False instead of True
// but not vice-versa. so its ok.
if (TORCH_GUARD_OR_FALSE(sym_eq(sizes[d], 1))) {
continue;
}
// Taking this branch could make this return False instead of True
// but not vice-versa. so its ok.
if (TORCH_GUARD_OR_TRUE(sym_ne(strides[d], expected))) {
return false;
}
expected *= size_d;
}
return true;
};
// We try to minimize creating large symbolic expressions when not needed
// to avoid symbolic evaluation perf issues.
if (guard_or_false()) {
return c10::SymBool(true);
}
// Result is either false, or data dependent.
c10::SymInt expected_stride = 1;
c10::SymBool cond = true;
for (auto& d : {1, 3, 2, 0}) {
const auto& size_d = sizes[d];
cond = cond.sym_and(
size_d.sym_eq(1).sym_or(sym_eq(strides[d], expected_stride)));
expected_stride *= size_d;
}
return cond;
}
// NOLINTNEXTLINE(bugprone-branch-clone)
case 3:
// TODO dim == 3 case will be enabled once it is fully tested
return c10::SymBool(false);
default:
return c10::SymBool(false);
}
}
// When T is SymInt this function may throw a data dependent error.
// _compute_channels_last_contiguous_3d_sym does not. Only use this function
// when inputs are hinted.
template <typename T>
bool _compute_channels_last_contiguous_3d(
ArrayRef<T> sizes,
@ -134,8 +199,8 @@ bool _compute_channels_last_contiguous_3d(
T expected = 1;
for (auto& d : {1, 4, 3, 2, 0}) {
const auto& size_d = sizes[d];
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(size_d, 1))) {
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(strides[d], expected))) {
if (size_d != 1) {
if (strides[d] != expected) {
return false;
}
expected *= size_d;
@ -152,6 +217,59 @@ bool _compute_channels_last_contiguous_3d(
}
}
inline static c10::SymBool _compute_channels_last_contiguous_3d_sym(
ArrayRef<c10::SymInt> sizes,
ArrayRef<c10::SymInt> strides) {
switch (sizes.size()) {
case 5: {
// When this function return True, result always true. When it return
// False, result could be False or data dependent.
auto guard_or_false = [&]() {
c10::SymInt expected = 1;
for (auto& d : {1, 4, 3, 2, 0}) {
const auto& size_d = sizes[d];
// Not taking this branch could make this return False instead of True
// but not vice-versa. so its ok.
if (TORCH_GUARD_OR_FALSE(sym_eq(sizes[d], 1))) {
continue;
}
// Taking this branch could make this return False instead of True
// but not vice-versa. so its ok.
if (TORCH_GUARD_OR_TRUE(sym_ne(strides[d], expected))) {
return false;
}
expected *= size_d;
}
return true;
};
// We try to minimize creating large symbolic expressions when not needed
// to avoid symbolic evaluation perf issues.
if (guard_or_false()) {
return c10::SymBool(true);
}
// Result is either false, or data dependent.
c10::SymInt expected_stride = 1;
c10::SymBool cond = true;
for (auto& d : {1, 4, 3, 2, 0}) {
const auto& size_d = sizes[d];
cond = cond.sym_and(
size_d.sym_eq(1).sym_or(sym_eq(strides[d], expected_stride)));
expected_stride *= size_d;
}
return cond;
}
// NOLINTNEXTLINE(bugprone-branch-clone)
case 4:
// TODO dim == 4 case will be enabled once it is fully tested
return c10::SymBool(false);
default:
return c10::SymBool(false);
}
}
template <typename T>
bool _compute_non_overlapping_and_dense(
ArrayRef<T> sizes,

View File

@ -71,6 +71,27 @@ normalize_sym_sizes_strides(SymIntArrayRef sizes, SymIntArrayRef strides) {
return std::tuple<SymNode, std::vector<SymNode>, std::vector<SymNode>>(
std::move(base), std::move(size_nodes), std::move(stride_nodes));
}
namespace {
bool all_hinted(
const c10::SymIntArrayRef& sizes,
const c10::SymIntArrayRef& strides) {
auto all_hinted = true;
for (const auto& s : sizes) {
if (!s.has_hint()) {
return false;
}
}
if (all_hinted) {
for (const auto& s : strides) {
if (!s.has_hint()) {
return false;
}
}
}
return all_hinted;
}
} // namespace
// Special treatment because of numel
SymBool SymbolicShapeMeta::compute_contiguous() const {
@ -88,24 +109,7 @@ SymBool SymbolicShapeMeta::compute_contiguous() const {
return maybe_as_bool.value();
}
auto all_hinted = true;
for (const auto& s : sizes) {
if (!s.has_hint()) {
all_hinted = false;
break;
}
}
if (all_hinted) {
for (const auto& s : strides) {
if (!s.has_hint()) {
all_hinted = false;
break;
}
}
}
if (all_hinted) {
if (all_hinted(sizes, strides)) {
// We avoid going through the slow path if everything is hinted,
// because evaluating a large SymPy expression can be expensive.
// TODO exclude backed_size_oblivious from this path.
@ -115,6 +119,56 @@ SymBool SymbolicShapeMeta::compute_contiguous() const {
return result;
}
SymBool SymbolicShapeMeta::compute_channels_last_contiguous_2d() const {
if (!strides_valid_) {
return false;
}
c10::SymIntArrayRef sizes(sizes_);
c10::SymIntArrayRef strides(strides_);
auto result = _compute_channels_last_contiguous_2d_sym(sizes, strides);
// If the result is already determined without guarding, just return it.
auto maybe_as_bool = result.maybe_as_bool();
if (maybe_as_bool.has_value()) {
return maybe_as_bool.value();
}
if (all_hinted(sizes, strides)) {
// We avoid going through the slow path if everything is hinted,
// because evaluating a large SymPy expression can be expensive.
// TODO exclude backed_size_oblivious from this path.
return _compute_channels_last_contiguous_2d<SymInt>(sizes_, strides_);
}
return result;
}
SymBool SymbolicShapeMeta::compute_channels_last_contiguous_3d() const {
if (!strides_valid_) {
return false;
}
c10::SymIntArrayRef sizes(sizes_);
c10::SymIntArrayRef strides(strides_);
auto result = _compute_channels_last_contiguous_3d_sym(sizes, strides);
// If the result is already determined without guarding, just return it.
auto maybe_as_bool = result.maybe_as_bool();
if (maybe_as_bool.has_value()) {
return maybe_as_bool.value();
}
if (all_hinted(sizes, strides)) {
// We avoid going through the slow path if everything is hinted,
// because evaluating a large SymPy expression can be expensive.
// TODO exclude backed_size_oblivious from this path.
return _compute_channels_last_contiguous_3d<SymInt>(sizes_, strides_);
}
return result;
}
// The rest of them
#define DEFINE_EAGER_SYMBOOL_COMPUTE(name, fallback) \
SymBool SymbolicShapeMeta::name() const { \
@ -143,8 +197,6 @@ SymBool SymbolicShapeMeta::compute_contiguous() const {
}
// clang-format off
DEFINE_EAGER_SYMBOOL_COMPUTE(compute_channels_last_contiguous_2d, _compute_channels_last_contiguous_2d)
DEFINE_EAGER_SYMBOOL_COMPUTE(compute_channels_last_contiguous_3d, _compute_channels_last_contiguous_3d)
DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_2d, is_channels_last_strides_2d)
DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_3d, is_channels_last_strides_3d)