Files
pytorch/c10/core/SymBool.h
cyy 1544c37520 [7/N] Fixes clang-tidy warnings in c10/{core,util}/*.h (#115495)
This PR continues to fix clang-tidy warnings for headers in c10/core and c10/util.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115495
Approved by: https://github.com/malfet
2023-12-19 02:14:30 +00:00

93 lines
2.2 KiB
C++

#pragma once
#include <c10/core/SymNodeImpl.h>
#include <c10/macros/Export.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <c10/util/intrusive_ptr.h>
#include <cstdint>
#include <ostream>
#include <utility>
namespace c10 {
class C10_API SymBool {
public:
/*implicit*/ SymBool(bool b) : data_(b){};
SymBool(SymNode ptr) : data_(false), ptr_(std::move(ptr)) {
TORCH_CHECK(ptr_->is_bool());
};
SymBool() : data_(false) {}
SymNodeImpl* toSymNodeImplUnowned() const {
return ptr_.get();
}
SymNodeImpl* release() && {
return std::move(ptr_).release();
}
// Only valid if is_heap_allocated()
SymNode toSymNodeImpl() const;
// Guaranteed to return a SymNode, wrapping using base if necessary
SymNode wrap_node(const SymNode& base) const;
bool expect_bool() const {
c10::optional<bool> c = maybe_as_bool();
TORCH_CHECK(c.has_value());
return *c;
}
SymBool sym_and(const SymBool&) const;
SymBool sym_or(const SymBool&) const;
SymBool sym_not() const;
SymBool operator&(const SymBool& other) const {
return sym_and(other);
}
SymBool operator|(const SymBool& other) const {
return sym_or(other);
}
SymBool operator~() const {
return sym_not();
}
// Insert a guard for the bool to be its concrete value, and then return
// that value. Note that C++ comparison operations default to returning
// bool, so it's not so common to have to call this
bool guard_bool(const char* file, int64_t line) const;
bool expect_true(const char* file, int64_t line) const;
bool has_hint() const;
bool as_bool_unchecked() const {
return data_;
}
c10::optional<bool> maybe_as_bool() const {
if (!is_heap_allocated()) {
return c10::make_optional(data_);
}
return toSymNodeImplUnowned()->constant_bool();
}
bool is_heap_allocated() const {
return ptr_;
}
private:
// TODO: optimize to union
bool data_;
SymNode ptr_;
};
// Stream insertion for SymBool; the definition lives out of line.
C10_API std::ostream& operator<<(std::ostream& os, const SymBool& s);

// Checked assertions whose condition is a SymBool (anything exposing
// `expect_true(const char* file, int64_t line)`). The condition is resolved
// through expect_true() using the invoking site's __FILE__ and __LINE__.
// The boolean result, plus any trailing arguments, is then handed to
// TORCH_CHECK (user-facing) or TORCH_INTERNAL_ASSERT (internal invariant).
#define TORCH_SYM_CHECK(sym_condition, ...)                                    \
  TORCH_CHECK(                                                                 \
      (sym_condition).expect_true(__FILE__, __LINE__), __VA_ARGS__)
#define TORCH_SYM_INTERNAL_ASSERT(sym_condition, ...)                          \
  TORCH_INTERNAL_ASSERT(                                                       \
      (sym_condition).expect_true(__FILE__, __LINE__), __VA_ARGS__)
} // namespace c10