Files
pytorch/c10/core/SymBool.cpp
Laith Sakka 872d1daec2 Avoid DDE in narrow with unbacked start (#166361)
Slice already knows how to handle an unbacked start, so we do not need to offset start before calling slice; we can leave that to slice.
The only edge case is when start < 0 and start + length == 0: in that case slice and narrow would diverge,
so for that case we pass dim_size instead of start + length.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166361
Approved by: https://github.com/aorenste
2025-11-06 01:04:19 +00:00

129 lines
3.8 KiB
C++

#include <c10/core/SymBool.h>
#include <c10/core/SymInt.h>
#include <c10/core/SymNodeImpl.h>
namespace c10 {
// Returns the symbolic node backing this value as a SymNode.
// Fails via TORCH_CHECK unless the value is heap-allocated. The raw,
// unowned node pointer is wrapped with SymNode::reclaim_copy (expected to
// take a fresh reference rather than adopt the existing one — confirm
// against intrusive_ptr docs).
SymNode SymBool::toSymNodeImpl() const {
  TORCH_CHECK(is_heap_allocated());
  auto* const raw_node = toSymNodeImplUnowned();
  return SymNode::reclaim_copy(raw_node);
}
// Returns this value as a SymNode. A concrete bool is wrapped through
// base->wrap_bool(); a symbolic value hands back its own node.
SymNode SymBool::wrap_node(const SymNode& base) const {
  const auto concrete = maybe_as_bool();
  if (!concrete.has_value()) {
    return toSymNodeImpl();
  }
  return base->wrap_bool(*concrete);
}
// Generates `RET SymBool::API(const SymBool& sci) const`.
//  * Both operands concrete: OP is applied to the two bools directly.
//  * Exactly one operand concrete: that bool is wrapped with wrap_bool() on
//    the symbolic operand's node, keeping its left/right position, and the
//    result is computed as left_node->METHOD(right_node).
//  * Both operands symbolic: this node's METHOD is called with sci's node.
// No comments appear inside the macro body, because a `//` comment on a
// backslash-continued line would also swallow the following line.
#define DEFINE_BINARY(API, OP, METHOD, RET)                        \
  RET SymBool::API(const SymBool& sci) const {                     \
    const auto lhs = maybe_as_bool();                              \
    const auto rhs = sci.maybe_as_bool();                          \
    if (lhs && rhs) {                                              \
      return RET(OP(*lhs, *rhs));                                  \
    }                                                              \
    if (lhs) {                                                     \
      auto rhs_node = sci.toSymNodeImpl();                         \
      return RET(rhs_node->wrap_bool(*lhs)->METHOD(rhs_node));     \
    }                                                              \
    const auto lhs_node = toSymNodeImplUnowned();                  \
    if (rhs) {                                                     \
      return RET(lhs_node->METHOD(lhs_node->wrap_bool(*rhs)));     \
    }                                                              \
    return RET(lhs_node->METHOD(sci.toSymNodeImpl()));             \
  }
// clang-format off
DEFINE_BINARY(sym_and, std::logical_and<>(), sym_and, SymBool)
DEFINE_BINARY(sym_or, std::logical_or<>(), sym_or, SymBool)
// clang-format on
// Logical negation. A concrete bool is negated directly; a symbolic value
// delegates to its node's sym_not().
SymBool SymBool::sym_not() const {
  const auto concrete = maybe_as_bool();
  return concrete ? SymBool(!*concrete) : SymBool(toSymNodeImpl()->sym_not());
}
// Streams a SymBool. A concrete value goes through the stream's bool
// inserter, so it is affected by std::boolalpha. A symbolic value streams its
// node's str() text. Returns the same stream.
std::ostream& operator<<(std::ostream& os, const SymBool& s) {
  const auto concrete = s.maybe_as_bool();
  if (!concrete) {
    return os << s.toSymNodeImpl()->str();
  }
  return os << *concrete;
}
bool SymBool::guard_bool(const char* file, int64_t line) const {
if (auto ma = maybe_as_bool()) {
return *ma;
}
SymNode a = toSymNodeImpl();
return a->guard_bool(file, line);
}
bool SymBool::guard_size_oblivious(const char* file, int64_t line) const {
if (auto ma = maybe_as_bool()) {
return *ma;
}
SymNode a = toSymNodeImpl();
return a->guard_size_oblivious(file, line);
}
bool SymBool::guard_or_false(const char* file, int64_t line) const {
if (auto ma = maybe_as_bool()) {
return *ma;
}
SymNode a = toSymNodeImpl();
return a->guard_or_false(file, line);
}
bool SymBool::statically_known_true(const char* file, int64_t line) const {
if (auto ma = maybe_as_bool()) {
return *ma;
}
SymNode a = toSymNodeImpl();
return a->statically_known_true(file, line);
}
bool SymBool::guard_or_true(const char* file, int64_t line) const {
if (auto ma = maybe_as_bool()) {
return *ma;
}
SymNode a = toSymNodeImpl();
return a->guard_or_true(file, line);
}
bool SymBool::expect_true(const char* file, int64_t line) const {
if (auto ma = maybe_as_bool()) {
return *ma;
}
SymNode a = toSymNodeImpl();
return a->expect_true(file, line);
}
// A concrete bool always counts as hinted. For a symbolic value the answer
// comes from the node's has_hint(), which is only queried when needed.
bool SymBool::has_hint() const {
  return maybe_as_bool().has_value() || toSymNodeImpl()->has_hint();
}
// Converts to a SymInt holding 1 for true and 0 for false.
// A concrete bool yields a concrete SymInt. A symbolic bool yields
// node->sym_ite(1, 0), with both integer branches created via the node's
// wrap_int().
SymInt SymBool::toSymInt() const {
  const auto concrete = maybe_as_bool();
  if (concrete.has_value()) {
    return SymInt(static_cast<int64_t>(*concrete ? 1 : 0));
  }
  auto cond = toSymNodeImpl();
  auto when_true = cond->wrap_int(1);
  auto when_false = cond->wrap_int(0);
  return SymInt(cond->sym_ite(when_true, when_false));
}
} // namespace c10