mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Generated by running the following from PyTorch root: ``` find . -regex ".*\.\(cpp\|h\|cu\|hpp\|cc\|cxx\)$" | grep -v "build/" | xargs -n 50 -P 4 perl -pi -e 's/c10::optional/std::optional/' ``` `c10::optional` is just an alias for `std::optional`. This removes usages of that alias in preparation for eliminating it entirely. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126135 Approved by: https://github.com/Skylion007, https://github.com/malfet, https://github.com/albanD, https://github.com/aaronenyeshi
81 lines
2.8 KiB
C++
81 lines
2.8 KiB
C++
#include <ATen/core/NestedIntSymNodeImpl.h>
|
|
#include <c10/core/SymNodeImpl.h>
|
|
#include <c10/util/Exception.h>
|
|
|
|
namespace c10 {
|
|
|
|
namespace {
|
|
bool _eq(const char* op, c10::SymNodeImpl* lhs, c10::SymNodeImpl* rhs) {
|
|
TORCH_INTERNAL_ASSERT(lhs->is_nested_int());
|
|
std::optional<int64_t> c = rhs->nested_int();
|
|
return (
|
|
c.has_value() && lhs->nested_int() == *c &&
|
|
lhs->nested_int_coeff() == rhs->nested_int_coeff());
|
|
}
|
|
bool _ge(const char* op, c10::SymNodeImpl* lhs, c10::SymNodeImpl* rhs) {
|
|
if (auto mb_si = lhs->nested_int()) {
|
|
if (auto mb_si2 = rhs->nested_int()) {
|
|
if (*mb_si == *mb_si2) {
|
|
return lhs->nested_int_coeff() >= rhs->nested_int_coeff();
|
|
}
|
|
TORCH_CHECK(false, "nested int ", op, ": Relation is indeterminate");
|
|
}
|
|
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
|
|
if (rhs->constant_int() && *rhs->constant_int() <= 2) {
|
|
return true;
|
|
}
|
|
TORCH_CHECK(false, "nested int ", op, ": Relation is indeterminate");
|
|
} else if (rhs->nested_int()) {
|
|
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
|
|
if (lhs->constant_int() && *lhs->constant_int() < 2) {
|
|
return false;
|
|
}
|
|
TORCH_CHECK(false, "nested int ", op, ": Relation is indeterminate");
|
|
}
|
|
TORCH_INTERNAL_ASSERT(false, "expect at least one nested int");
|
|
}
|
|
} // namespace
|
|
|
|
// Returns a constant boolean node holding the result of _eq(this, other).
c10::SymNode NestedIntSymNodeImpl::eq(const c10::SymNode& other) {
  using BoolNode = ConstantSymNodeImpl<bool>;
  const bool is_equal = _eq("eq", this, other.get());
  return SymNode(c10::make_intrusive<BoolNode>(is_equal));
}
|
|
|
|
// Returns a constant boolean node holding the negation of _eq(this, other).
c10::SymNode NestedIntSymNodeImpl::ne(const c10::SymNode& other) {
  using BoolNode = ConstantSymNodeImpl<bool>;
  const bool is_different = !_eq("ne", this, other.get());
  return SymNode(c10::make_intrusive<BoolNode>(is_different));
}
|
|
|
|
// Returns a constant boolean node for `this >= other`, delegating to _ge.
// Any error raised by _ge propagates to the caller.
c10::SymNode NestedIntSymNodeImpl::ge(const c10::SymNode& other) {
  using BoolNode = ConstantSymNodeImpl<bool>;
  const bool is_ge = _ge("ge", this, other.get());
  return SymNode(c10::make_intrusive<BoolNode>(is_ge));
}
|
|
|
|
// Returns a constant boolean node for `this > other`, computed as the
// negation of `other >= this` via _ge. Any error raised by _ge propagates.
c10::SymNode NestedIntSymNodeImpl::gt(const c10::SymNode& other) {
  using BoolNode = ConstantSymNodeImpl<bool>;
  const bool other_ge_this = _ge("gt", other.get(), this);
  return SymNode(c10::make_intrusive<BoolNode>(!other_ge_this));
}
|
|
|
|
// Returns a constant boolean node for `this < other`, computed as the
// negation of `this >= other` via _ge. Any error raised by _ge propagates.
c10::SymNode NestedIntSymNodeImpl::lt(const c10::SymNode& other) {
  using BoolNode = ConstantSymNodeImpl<bool>;
  const bool this_ge_other = _ge("lt", this, other.get());
  return SymNode(c10::make_intrusive<BoolNode>(!this_ge_other));
}
|
|
|
|
// Returns a constant boolean node for `this <= other`, computed as
// `other >= this` via _ge. Any error raised by _ge propagates.
c10::SymNode NestedIntSymNodeImpl::le(const c10::SymNode& other) {
  using BoolNode = ConstantSymNodeImpl<bool>;
  const bool other_ge_this = _ge("le", other.get(), this);
  return SymNode(c10::make_intrusive<BoolNode>(other_ge_this));
}
|
|
|
|
// Multiplies this nested int by a constant integer.
//
// `other` must not be a nested int and must expose a constant_int() value;
// either violation fails a TORCH_CHECK with a descriptive message.
//
// Returns a new NestedIntSymNodeImpl built from the same val_ and from
// coeff_ multiplied by the constant.
c10::SymNode NestedIntSymNodeImpl::mul(const c10::SymNode& other) {
  TORCH_CHECK(
      !other->nested_int(), "nested int cannot be multiplied by nested int");
  const std::optional<int64_t> factor = other->constant_int();
  // Give an explicit diagnostic instead of the generic auto-generated one.
  TORCH_CHECK(
      factor.has_value(),
      "nested int can only be multiplied by a constant int");
  return SymNode(
      c10::make_intrusive<NestedIntSymNodeImpl>(val_, coeff_ * *factor));
}
|
|
|
|
// Returns a freshly allocated node constructed from this node's val_ and
// coeff_.
c10::SymNode NestedIntSymNodeImpl::clone() {
  using Self = NestedIntSymNodeImpl;
  return SymNode(c10::make_intrusive<Self>(val_, coeff_));
}
|
|
|
|
} // namespace c10
|