mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Follows #132010 Pull Request resolved: https://github.com/pytorch/pytorch/pull/132122 Approved by: https://github.com/Skylion007
371 lines
11 KiB
C++
371 lines
11 KiB
C++
#include <torch/csrc/jit/tensorexpr/bounds_overlap.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
#include <torch/csrc/jit/tensorexpr/stmt.h>

#include <cstdint>
#include <iostream>
|
|
|
|
namespace torch::jit::tensorexpr::analysis {
|
|
|
|
// Returns true if the given expression is guaranteed to be positive.
|
|
static bool mustBePositive(const ExprPtr& e) {
|
|
if (e->isConstant()) {
|
|
int e_val = immediateAs<int>(e);
|
|
return e_val > 0;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
// Returns true if the given expression is guaranteed to be negative.
|
|
static bool mustBeNegative(const ExprPtr& e) {
|
|
if (e->isConstant()) {
|
|
int e_val = immediateAs<int>(e);
|
|
return e_val < 0;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
// Returns true if the given expression is guaranteed to be zero.
|
|
static bool mustBeZero(const ExprPtr& e) {
|
|
if (e->isConstant()) {
|
|
int e_val = immediateAs<int>(e);
|
|
return e_val == 0;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
// Writes this bound to std::cout in the form "(start, end)".
void Bound::print() const {
  std::cout << "(" << *start;
  std::cout << ", " << *end;
  std::cout << ")";
}
|
|
|
|
bool Bound::equals(const Bound& other) const {
|
|
return exprEquals(start, other.start) && exprEquals(end, other.end);
|
|
}
|
|
|
|
// Value equality between two bounds. Holds only when the bounds are
// structurally identical AND `start - end` simplifies to the constant zero,
// i.e. the bound covers a single point.
bool Bound::operator==(const Bound& other) const {
  if (!equals(other)) {
    return false;
  }

  const auto extent = IRSimplifier::simplify(alloc<Sub>(start, end));
  return mustBeZero(extent);
}
|
|
|
|
// Two bounds are provably different when one lies strictly below or strictly
// above the other.
bool Bound::operator!=(const Bound& other) const {
  if (*this < other) {
    return true;
  }
  return *this > other;
}
|
|
|
|
// True when the bounds are equal (operator==), or when `start - other.end`
// simplifies to a constant that is positive or zero.
bool Bound::operator>=(const Bound& other) const {
  if (*this == other) {
    return true;
  }
  const auto gap = IRSimplifier::simplify(alloc<Sub>(start, other.end));
  if (mustBePositive(gap)) {
    return true;
  }
  return mustBeZero(gap);
}
|
|
|
|
// True when `start - other.end` simplifies to a strictly positive constant,
// i.e. this bound begins after `other` ends.
bool Bound::operator>(const Bound& other) const {
  const auto gap = IRSimplifier::simplify(alloc<Sub>(start, other.end));
  return mustBePositive(gap);
}
|
|
|
|
// True when the bounds are equal (operator==), or when `end - other.start`
// simplifies to a constant that is negative or zero.
bool Bound::operator<=(const Bound& other) const {
  if (*this == other) {
    return true;
  }
  const auto gap = IRSimplifier::simplify(alloc<Sub>(end, other.start));
  if (mustBeNegative(gap)) {
    return true;
  }
  return mustBeZero(gap);
}
|
|
|
|
// True when `end - other.start` simplifies to a strictly negative constant,
// i.e. this bound ends before `other` begins.
bool Bound::operator<(const Bound& other) const {
  const auto gap = IRSimplifier::simplify(alloc<Sub>(end, other.start));
  return mustBeNegative(gap);
}
|
|
|
|
// Classifies how bound `a` relates to bound `b`.
//
// Returns:
//   ContainedOrEqual - the bounds are structurally identical, or constant
//                      differences show b starts no later than a and ends no
//                      earlier than a (a lies inside b).
//   Contains         - constant differences show b lies inside a.
//   NoOverlap        - one bound provably starts after the other ends.
//   PartialOverlap   - the conservative answer for everything else.
OverlapKind boundOverlap(const Bound& a, const Bound& b) {
  // Structurally identical bounds are trivially equal.
  const bool sameStart = exprEquals(a.start, b.start);
  const bool sameEnd = exprEquals(a.end, b.end);
  if (sameStart && sameEnd) {
    return OverlapKind::ContainedOrEqual;
  }

  // The bounds are disjoint in exactly two layouts:
  //   1. a is before b:  a.start ... a.end ... b.start ... b.end
  //   2. b is before a:  b.start ... b.end ... a.start ... a.end
  //
  // So we compute "a.start - b.end" and "b.start - a.end". If either one is
  // positive the bounds cannot overlap. mustBePositive only decides constant
  // differences; a difference that does not simplify to a constant is treated
  // as "unknown" and falls through to the checks below.
  ExprPtr lowDiff = IRSimplifier::simplify(alloc<Sub>(a.start, b.end));
  ExprPtr highDiff = IRSimplifier::simplify(alloc<Sub>(b.start, a.end));

  if (mustBePositive(lowDiff)) {
    return OverlapKind::NoOverlap;
  }
  if (mustBePositive(highDiff)) {
    return OverlapKind::NoOverlap;
  }

  ExprPtr diff_start = IRSimplifier::simplify(alloc<Sub>(b.start, a.start));
  ExprPtr diff_end = IRSimplifier::simplify(alloc<Sub>(b.end, a.end));

  // If one side fully encloses the other, report the enclosure. This is only
  // decidable when both differences are constants.
  if (diff_start->isConstant() && diff_end->isConstant()) {
    // Read the differences at 64 bits; narrowing to `int` can wrap and change
    // the sign of a wide immediate.
    const int64_t startDelta = immediateAs<int64_t>(diff_start);
    const int64_t endDelta = immediateAs<int64_t>(diff_end);

    // b.start <= a.start and b.end >= a.end: a lies inside b.
    if (startDelta <= 0 && endDelta >= 0) {
      return OverlapKind::ContainedOrEqual;
    }

    // b.start >= a.start and b.end <= a.end: b lies inside a.
    if (startDelta >= 0 && endDelta <= 0) {
      return OverlapKind::Contains;
    }
  }

  // We can't be sure there's no overlap so the conservative answer is
  // partial.
  return OverlapKind::PartialOverlap;
}
|
|
|
|
// Evaluates the comparison `a <cmp_op> b` on two bounds.
//
// For each operation the bound operator for the comparison itself is tried
// first; if it holds the result is True. Otherwise the opposite bound
// operator is tried; if that holds the result is False. When neither can be
// established the result is NotDetermined. Any operation other than
// kGT/kGE/kLT/kLE/kNE must be kEQ (asserted).
CmpEvalResult TORCH_API compareBound(
    const Bound& a,
    const Bound& b,
    const CompareSelectOperation& cmp_op) {
  switch (cmp_op) {
    case CompareSelectOperation::kGT:
      if (a > b) {
        return CmpEvalResult::True;
      }
      if (a <= b) {
        return CmpEvalResult::False;
      }
      return CmpEvalResult::NotDetermined;
    case CompareSelectOperation::kGE:
      if (a >= b) {
        return CmpEvalResult::True;
      }
      if (a < b) {
        return CmpEvalResult::False;
      }
      return CmpEvalResult::NotDetermined;
    case CompareSelectOperation::kLT:
      if (a < b) {
        return CmpEvalResult::True;
      }
      if (a >= b) {
        return CmpEvalResult::False;
      }
      return CmpEvalResult::NotDetermined;
    case CompareSelectOperation::kLE:
      if (a <= b) {
        return CmpEvalResult::True;
      }
      if (a > b) {
        return CmpEvalResult::False;
      }
      return CmpEvalResult::NotDetermined;
    case CompareSelectOperation::kNE:
      if (a != b) {
        return CmpEvalResult::True;
      }
      if (a == b) {
        return CmpEvalResult::False;
      }
      return CmpEvalResult::NotDetermined;
    default:
      TORCH_INTERNAL_ASSERT(cmp_op == CompareSelectOperation::kEQ);
      if (a == b) {
        return CmpEvalResult::True;
      }
      if (a != b) {
        return CmpEvalResult::False;
      }
      return CmpEvalResult::NotDetermined;
  }
}
|
|
|
|
bool indexBoundsEquals(const IndexBounds& A, const IndexBounds& B) {
|
|
if (A.size() != B.size()) {
|
|
return false;
|
|
}
|
|
|
|
for (size_t i = 0; i != A.size(); ++i) {
|
|
if (!A[i].equals(B[i])) {
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
// Collapses a multi-dimensional set of bounds into a single Bound.
//
// Starting from a copy of the first dimension's bound, the start is
// multiplied by each later dimension's start and the end by each later
// dimension's end; both products are then simplified. An empty input yields
// a default-constructed Bound.
Bound flattenBounds(const IndexBounds& a) {
  if (a.empty()) {
    return Bound();
  }

  Bound flat = a[0];
  const size_t dims = a.size();
  for (size_t dim = 1; dim < dims; ++dim) {
    const Bound& next = a[dim];
    flat.start = alloc<Mul>(flat.start, next.start);
    flat.end = alloc<Mul>(flat.end, next.end);
  }

  flat.start = IRSimplifier::simplify(flat.start);
  flat.end = IRSimplifier::simplify(flat.end);
  return flat;
}
|
|
|
|
// Classifies the overlap between two multi-dimensional index bounds by
// combining the per-dimension results of boundOverlap.
//
//  - Two empty bounds are ContainedOrEqual.
//  - Bounds of different rank are flattened to one dimension and compared.
//  - If any dimension has NoOverlap, the regions are disjoint: NoOverlap.
//  - Otherwise the per-dimension kinds are merged; a mismatch that is not a
//    ContainedOrEqual/Contains combination yields PartialOverlap.
OverlapKind overlaps(const IndexBounds& a, const IndexBounds& b) {
  if (a.empty() && b.empty()) {
    return OverlapKind::ContainedOrEqual;
  }

  // All accesses to a buf must have the same dimensionality; when they do
  // not, compare the flattened bounds instead.
  if (a.size() != b.size()) {
    return boundOverlap(flattenBounds(a), flattenBounds(b));
  }
  TORCH_INTERNAL_ASSERT(a.size() == b.size());

  OverlapKind overlap = boundOverlap(a[0], b[0]);
  // Regions that are disjoint along any single dimension are disjoint
  // overall. Check the first dimension here so it is treated the same way as
  // the later ones, instead of degrading to PartialOverlap in the loop.
  if (overlap == OverlapKind::NoOverlap) {
    return OverlapKind::NoOverlap;
  }

  for (size_t i = 1; i < a.size(); ++i) {
    OverlapKind bOverlap = boundOverlap(a[i], b[i]);
    if (bOverlap == OverlapKind::NoOverlap) {
      return OverlapKind::NoOverlap;
    }

    // ContainedOrEqual so far, Contains in this dimension: promote the
    // running result to Contains.
    if (overlap == OverlapKind::ContainedOrEqual &&
        bOverlap == OverlapKind::Contains) {
      overlap = OverlapKind::Contains;
    }

    // Contains so far, ContainedOrEqual in this dimension: keep Contains.
    if (overlap == OverlapKind::Contains &&
        bOverlap == OverlapKind::ContainedOrEqual) {
      continue;
    }

    // Any other disagreement between dimensions is only a partial overlap.
    if (bOverlap != overlap) {
      overlap = OverlapKind::PartialOverlap;
      break;
    }
  }

  return overlap;
}
|
|
|
|
std::vector<Bound> subtractBound(const Bound& a, const Bound& b) {
|
|
OverlapKind overlap = boundOverlap(a, b);
|
|
if (overlap == OverlapKind::NoOverlap) {
|
|
return {a};
|
|
}
|
|
if (overlap == OverlapKind::ContainedOrEqual) {
|
|
return {};
|
|
}
|
|
|
|
// The bounds must overlap.
|
|
std::vector<Bound> res;
|
|
|
|
if (a.start->isConstant() != b.start->isConstant() ||
|
|
a.end->isConstant() != b.end->isConstant()) {
|
|
return {a};
|
|
}
|
|
|
|
ExprPtr lowDiff = IRSimplifier::simplify(alloc<Sub>(b.start, a.start));
|
|
ExprPtr highDiff = IRSimplifier::simplify(alloc<Sub>(b.end, a.end));
|
|
|
|
// If the diff has only a single var, we can try to guess sign.
|
|
if (!lowDiff->isConstant()) {
|
|
auto vars = VarFinder::find(lowDiff);
|
|
if (vars.size() == 1) {
|
|
lowDiff = IRSimplifier::simplify(alloc<Sub>(
|
|
SubstituteInClone(b.start, {{*vars.begin(), immLike(b.start, 1)}}),
|
|
SubstituteInClone(a.start, {{*vars.begin(), immLike(a.start, 1)}})));
|
|
}
|
|
}
|
|
|
|
if (!highDiff->isConstant()) {
|
|
auto vars = VarFinder::find(highDiff);
|
|
if (vars.size() == 1) {
|
|
highDiff = IRSimplifier::simplify(alloc<Sub>(
|
|
SubstituteInClone(b.end, {{*vars.begin(), immLike(b.end, 1)}}),
|
|
SubstituteInClone(a.end, {{*vars.begin(), immLike(a.end, 1)}})));
|
|
}
|
|
}
|
|
|
|
bool hasHead = lowDiff->isConstant() && immediateAs<int>(lowDiff) > 0;
|
|
bool hasTail = highDiff->isConstant() && immediateAs<int>(highDiff) < 0;
|
|
|
|
bool constantExtents = lowDiff->isConstant() && highDiff->isConstant();
|
|
|
|
if (!constantExtents) {
|
|
// If we can't infer the bound lengths, there's no way to create a safe
|
|
// subset. Just bail out.
|
|
return {a};
|
|
}
|
|
|
|
if (hasHead) {
|
|
res.emplace_back(
|
|
a.start,
|
|
IRSimplifier::simplify(alloc<Sub>(b.start, immLike(b.start, 1))));
|
|
}
|
|
|
|
if (hasTail) {
|
|
ExprPtr tailStart =
|
|
IRSimplifier::simplify(alloc<Add>(b.end, immLike(b.end, 1)));
|
|
res.emplace_back(tailStart, a.end);
|
|
}
|
|
|
|
return res;
|
|
}
|
|
|
|
std::vector<IndexBounds> subtractIndicesBounds(
|
|
const IndexBounds& A,
|
|
const IndexBounds& B,
|
|
OverlapKind overlap) {
|
|
if (overlap == OverlapKind::NoOverlap) {
|
|
return {A};
|
|
}
|
|
|
|
if (overlap == OverlapKind::ContainedOrEqual) {
|
|
return {};
|
|
}
|
|
// All accesses to a buf must have the same dimensionality.
|
|
TORCH_INTERNAL_ASSERT(A.size() == B.size(), buildErrorMessage());
|
|
|
|
// Each dimension can be sliced into multiple bound segments.
|
|
std::vector<IndexBounds> boundSlices;
|
|
std::vector<Bound> remainingOuterBounds;
|
|
|
|
for (size_t i = 0; i < A.size(); ++i) {
|
|
auto slices = subtractBound(A[i], B[i]);
|
|
|
|
Bound remaining = A[i];
|
|
|
|
for (const auto& slice : slices) {
|
|
IndexBounds newRegion;
|
|
newRegion.reserve(A.size());
|
|
TORCH_INTERNAL_ASSERT(
|
|
remainingOuterBounds.size() == i, buildErrorMessage());
|
|
|
|
for (size_t j = 0; j < i; ++j) {
|
|
newRegion.push_back(remainingOuterBounds[j]);
|
|
}
|
|
newRegion.push_back(slice);
|
|
for (size_t j = i + 1; j < A.size(); ++j) {
|
|
newRegion.push_back(A[j]);
|
|
}
|
|
|
|
boundSlices.push_back(newRegion);
|
|
|
|
if (slice.equals(A[i])) {
|
|
remaining = A[i];
|
|
} else {
|
|
auto remainingSlices = subtractBound(remaining, slice);
|
|
// In some cases, we might end up with empty remainingSlices due to the
|
|
// optimization done in subtraction while handling diff expressions
|
|
// that have a single variable in `subtractBound()`.
|
|
if (!remainingSlices.empty()) {
|
|
TORCH_INTERNAL_ASSERT(
|
|
remainingSlices.size() == 1, buildErrorMessage());
|
|
remaining = remainingSlices[0];
|
|
}
|
|
}
|
|
}
|
|
|
|
remainingOuterBounds.push_back(remaining);
|
|
}
|
|
|
|
return boundSlices;
|
|
}
|
|
|
|
std::vector<IndexBounds> TORCH_API
|
|
subtractIndicesBounds(const IndexBounds& A, const IndexBounds& B) {
|
|
return subtractIndicesBounds(A, B, overlaps(A, B));
|
|
}
|
|
|
|
} // namespace torch::jit::tensorexpr::analysis
|