mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add inline fast paths for SymInt operators (#161586)
If SymInt::maybe_as_int() returns non-empty, then we get an inline fast path. The philosophy here (as with the previous PR) is to preserve performance in the "plain old ints" case. Observed time spent in SymInt functions in computeStorageNBytes to drop (and not cost shift elsewhere in the function) after this change, profiling detach() using code similar to the benchmark from #160580 and Linux perf. Differential Revision: [D81530107](https://our.internmc.facebook.com/intern/diff/D81530107) Pull Request resolved: https://github.com/pytorch/pytorch/pull/161586 Approved by: https://github.com/ezyang ghstack dependencies: #161466
This commit is contained in:
committed by
PyTorch MergeBot
parent
fa1514acf1
commit
b0a3e58dd7
@ -53,12 +53,11 @@ bool SymInt::has_hint() const {
|
||||
#define DEFINE_BINARY(API, OP, METHOD, RET) \
|
||||
RET SymInt::API(const SymInt& sci) const { \
|
||||
if (auto ma = maybe_as_int()) { \
|
||||
if (auto mb = sci.maybe_as_int()) { \
|
||||
return RET(OP(*ma, *mb)); \
|
||||
} else { \
|
||||
auto b = sci.toSymNode(); \
|
||||
return RET(b->wrap_int(*ma)->METHOD(b)); \
|
||||
} \
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY( \
|
||||
!sci.maybe_as_int(), \
|
||||
"should have hit fast path in the header in this case."); \
|
||||
auto b = sci.toSymNode(); \
|
||||
return RET(b->wrap_int(*ma)->METHOD(b)); \
|
||||
} else { \
|
||||
if (auto mb = sci.maybe_as_int()) { \
|
||||
auto a = toSymNodeImplUnowned(); \
|
||||
@ -69,19 +68,19 @@ bool SymInt::has_hint() const {
|
||||
} \
|
||||
}
|
||||
|
||||
DEFINE_BINARY(operator+, std::plus<>(), add, SymInt)
|
||||
DEFINE_BINARY(operator-, std::minus<>(), sub, SymInt)
|
||||
DEFINE_BINARY(operator*, std::multiplies<>(), mul, SymInt)
|
||||
DEFINE_BINARY(operator/, std::divides<>(), floordiv, SymInt)
|
||||
DEFINE_BINARY(operator%, std::modulus<>(), mod, SymInt)
|
||||
DEFINE_BINARY(sym_eq, std::equal_to<>(), eq, SymBool)
|
||||
DEFINE_BINARY(sym_ne, std::not_equal_to<>(), ne, SymBool)
|
||||
DEFINE_BINARY(sym_lt, std::less<>(), lt, SymBool)
|
||||
DEFINE_BINARY(sym_le, std::less_equal<>(), le, SymBool)
|
||||
DEFINE_BINARY(sym_gt, std::greater<>(), gt, SymBool)
|
||||
DEFINE_BINARY(sym_ge, std::greater_equal<>(), ge, SymBool)
|
||||
DEFINE_BINARY(min, std::min, sym_min, SymInt)
|
||||
DEFINE_BINARY(max, std::max, sym_max, SymInt)
|
||||
DEFINE_BINARY(operator_add_slow_path, std::plus<>(), add, SymInt)
|
||||
DEFINE_BINARY(operator_sub_slow_path, std::minus<>(), sub, SymInt)
|
||||
DEFINE_BINARY(operator_mul_slow_path, std::multiplies<>(), mul, SymInt)
|
||||
DEFINE_BINARY(operator_div_slow_path, std::divides<>(), floordiv, SymInt)
|
||||
DEFINE_BINARY(operator_mod_slow_path, std::modulus<>(), mod, SymInt)
|
||||
DEFINE_BINARY(sym_eq_slow_path, std::equal_to<>(), eq, SymBool)
|
||||
DEFINE_BINARY(sym_ne_slow_path, std::not_equal_to<>(), ne, SymBool)
|
||||
DEFINE_BINARY(sym_lt_slow_path, std::less<>(), lt, SymBool)
|
||||
DEFINE_BINARY(sym_le_slow_path, std::less_equal<>(), le, SymBool)
|
||||
DEFINE_BINARY(sym_gt_slow_path, std::greater<>(), gt, SymBool)
|
||||
DEFINE_BINARY(sym_ge_slow_path, std::greater_equal<>(), ge, SymBool)
|
||||
DEFINE_BINARY(min_slow_path, std::min, sym_min, SymInt)
|
||||
DEFINE_BINARY(max_slow_path, std::max, sym_max, SymInt)
|
||||
|
||||
SymInt::operator SymFloat() const {
|
||||
if (auto ma = maybe_as_int()) {
|
||||
@ -161,15 +160,15 @@ SymInt operator-(const SymInt& s) {
|
||||
}
|
||||
}
|
||||
|
||||
void SymInt::operator*=(const SymInt& sci) {
|
||||
void SymInt::operator_imul_slow_path(const SymInt& sci) {
|
||||
*this = *this * sci;
|
||||
}
|
||||
|
||||
void SymInt::operator/=(const SymInt& sci) {
|
||||
void SymInt::operator_idiv_slow_path(const SymInt& sci) {
|
||||
*this = *this / sci;
|
||||
}
|
||||
|
||||
void SymInt::operator+=(const SymInt& sci) {
|
||||
void SymInt::operator_iadd_slow_path(const SymInt& sci) {
|
||||
*this = *this + sci;
|
||||
}
|
||||
|
||||
|
@ -7,6 +7,7 @@
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/Optional.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <iterator>
|
||||
#include <numeric>
|
||||
@ -177,23 +178,136 @@ class C10_API SymInt {
|
||||
#endif
|
||||
}
|
||||
|
||||
SymInt operator+(const SymInt& sci) const;
|
||||
SymInt operator-(const SymInt& sci) const;
|
||||
SymInt operator*(const SymInt& sci) const;
|
||||
SymInt operator/(const SymInt& sci) const;
|
||||
SymInt operator%(const SymInt& sci) const;
|
||||
void operator*=(const SymInt& sci);
|
||||
void operator+=(const SymInt& sci);
|
||||
void operator/=(const SymInt& sci);
|
||||
SymInt operator+(const SymInt& sci) const {
|
||||
if (auto ma = maybe_as_int()) {
|
||||
if (auto mb = sci.maybe_as_int()) {
|
||||
return SymInt(*ma + *mb);
|
||||
}
|
||||
}
|
||||
return operator_add_slow_path(sci);
|
||||
}
|
||||
|
||||
SymInt operator-(const SymInt& sci) const {
|
||||
if (auto ma = maybe_as_int()) {
|
||||
if (auto mb = sci.maybe_as_int()) {
|
||||
return SymInt(*ma - *mb);
|
||||
}
|
||||
}
|
||||
return operator_sub_slow_path(sci);
|
||||
}
|
||||
|
||||
SymInt operator*(const SymInt& sci) const {
|
||||
if (auto ma = maybe_as_int()) {
|
||||
if (auto mb = sci.maybe_as_int()) {
|
||||
return SymInt(*ma * *mb);
|
||||
}
|
||||
}
|
||||
return operator_mul_slow_path(sci);
|
||||
}
|
||||
|
||||
SymInt operator/(const SymInt& sci) const {
|
||||
if (auto ma = maybe_as_int()) {
|
||||
if (auto mb = sci.maybe_as_int()) {
|
||||
return SymInt(*ma / *mb);
|
||||
}
|
||||
}
|
||||
return operator_div_slow_path(sci);
|
||||
}
|
||||
|
||||
SymInt operator%(const SymInt& sci) const {
|
||||
if (auto ma = maybe_as_int()) {
|
||||
if (auto mb = sci.maybe_as_int()) {
|
||||
return SymInt(*ma % *mb);
|
||||
}
|
||||
}
|
||||
return operator_mod_slow_path(sci);
|
||||
}
|
||||
|
||||
void operator*=(const SymInt& sci) {
|
||||
if (auto ma = maybe_as_int()) {
|
||||
if (auto mb = sci.maybe_as_int()) {
|
||||
*this = SymInt(*ma * *mb);
|
||||
return;
|
||||
}
|
||||
}
|
||||
operator_imul_slow_path(sci);
|
||||
}
|
||||
|
||||
void operator+=(const SymInt& sci) {
|
||||
if (auto ma = maybe_as_int()) {
|
||||
if (auto mb = sci.maybe_as_int()) {
|
||||
*this = SymInt(*ma + *mb);
|
||||
return;
|
||||
}
|
||||
}
|
||||
operator_iadd_slow_path(sci);
|
||||
}
|
||||
|
||||
void operator/=(const SymInt& sci) {
|
||||
if (auto ma = maybe_as_int()) {
|
||||
if (auto mb = sci.maybe_as_int()) {
|
||||
*this = SymInt(*ma / *mb);
|
||||
return;
|
||||
}
|
||||
}
|
||||
operator_idiv_slow_path(sci);
|
||||
}
|
||||
|
||||
SymInt clone() const;
|
||||
|
||||
SymBool sym_eq(const SymInt&) const;
|
||||
SymBool sym_ne(const SymInt&) const;
|
||||
SymBool sym_lt(const SymInt&) const;
|
||||
SymBool sym_le(const SymInt&) const;
|
||||
SymBool sym_gt(const SymInt&) const;
|
||||
SymBool sym_ge(const SymInt&) const;
|
||||
SymBool sym_eq(const SymInt& sci) const {
|
||||
if (auto ma = maybe_as_int()) {
|
||||
if (auto mb = sci.maybe_as_int()) {
|
||||
return SymBool(*ma == *mb);
|
||||
}
|
||||
}
|
||||
return sym_eq_slow_path(sci);
|
||||
}
|
||||
|
||||
SymBool sym_ne(const SymInt& sci) const {
|
||||
if (auto ma = maybe_as_int()) {
|
||||
if (auto mb = sci.maybe_as_int()) {
|
||||
return SymBool(*ma != *mb);
|
||||
}
|
||||
}
|
||||
return sym_ne_slow_path(sci);
|
||||
}
|
||||
|
||||
SymBool sym_lt(const SymInt& sci) const {
|
||||
if (auto ma = maybe_as_int()) {
|
||||
if (auto mb = sci.maybe_as_int()) {
|
||||
return SymBool(*ma < *mb);
|
||||
}
|
||||
}
|
||||
return sym_lt_slow_path(sci);
|
||||
}
|
||||
|
||||
SymBool sym_le(const SymInt& sci) const {
|
||||
if (auto ma = maybe_as_int()) {
|
||||
if (auto mb = sci.maybe_as_int()) {
|
||||
return SymBool(*ma <= *mb);
|
||||
}
|
||||
}
|
||||
return sym_le_slow_path(sci);
|
||||
}
|
||||
|
||||
SymBool sym_gt(const SymInt& sci) const {
|
||||
if (auto ma = maybe_as_int()) {
|
||||
if (auto mb = sci.maybe_as_int()) {
|
||||
return SymBool(*ma > *mb);
|
||||
}
|
||||
}
|
||||
return sym_gt_slow_path(sci);
|
||||
}
|
||||
|
||||
SymBool sym_ge(const SymInt& sci) const {
|
||||
if (auto ma = maybe_as_int()) {
|
||||
if (auto mb = sci.maybe_as_int()) {
|
||||
return SymBool(*ma >= *mb);
|
||||
}
|
||||
}
|
||||
return sym_ge_slow_path(sci);
|
||||
}
|
||||
|
||||
bool operator==(const SymInt& o) const {
|
||||
return sym_eq(o).guard_bool(__FILE__, __LINE__);
|
||||
@ -214,8 +328,23 @@ class C10_API SymInt {
|
||||
return sym_ge(o).guard_bool(__FILE__, __LINE__);
|
||||
}
|
||||
|
||||
SymInt min(const SymInt& sci) const;
|
||||
SymInt max(const SymInt& sci) const;
|
||||
SymInt min(const SymInt& sci) const {
|
||||
if (auto ma = maybe_as_int()) {
|
||||
if (auto mb = sci.maybe_as_int()) {
|
||||
return SymInt(std::min(*ma, *mb));
|
||||
}
|
||||
}
|
||||
return min_slow_path(sci);
|
||||
}
|
||||
|
||||
SymInt max(const SymInt& sci) const {
|
||||
if (auto ma = maybe_as_int()) {
|
||||
if (auto mb = sci.maybe_as_int()) {
|
||||
return SymInt(std::max(*ma, *mb));
|
||||
}
|
||||
}
|
||||
return max_slow_path(sci);
|
||||
}
|
||||
|
||||
// If both are symbolic, this checks if
|
||||
// they share the same node.
|
||||
@ -260,6 +389,23 @@ class C10_API SymInt {
|
||||
|
||||
private:
|
||||
void promote_to_negative();
|
||||
SymInt operator_add_slow_path(const SymInt& sci) const;
|
||||
SymInt operator_sub_slow_path(const SymInt& sci) const;
|
||||
SymInt operator_mul_slow_path(const SymInt& sci) const;
|
||||
SymInt operator_div_slow_path(const SymInt& sci) const;
|
||||
SymInt operator_mod_slow_path(const SymInt& sci) const;
|
||||
void operator_imul_slow_path(const SymInt& sci);
|
||||
void operator_iadd_slow_path(const SymInt& sci);
|
||||
void operator_idiv_slow_path(const SymInt& sci);
|
||||
SymBool sym_eq_slow_path(const SymInt& sci) const;
|
||||
SymBool sym_ne_slow_path(const SymInt& sci) const;
|
||||
SymBool sym_lt_slow_path(const SymInt& sci) const;
|
||||
SymBool sym_le_slow_path(const SymInt& sci) const;
|
||||
SymBool sym_gt_slow_path(const SymInt& sci) const;
|
||||
SymBool sym_ge_slow_path(const SymInt& sci) const;
|
||||
|
||||
SymInt min_slow_path(const SymInt& sci) const;
|
||||
SymInt max_slow_path(const SymInt& sci) const;
|
||||
|
||||
std::optional<int64_t> maybe_as_int_slow_path() const;
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <c10/core/ConstantSymNodeImpl.h>
|
||||
#include <c10/core/SymInt.h>
|
||||
#include <c10/core/SymNodeImpl.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
@ -35,4 +36,169 @@ TEST(SymIntTest, Overflows) {
|
||||
}
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
|
||||
// We need a SymNodeImpl that 1) has working arithmetic with
|
||||
// predictable results and 2) causes SymInt::maybe_as_int to return
|
||||
// nullopt so that we can hit all 4 cases (zero/one/both arguments
|
||||
// have null maybe_as_int) in the operator implementations.
|
||||
class ConstantIntPretendingToBeSymbolicSymNodeImpl
|
||||
: public ConstantSymNodeImpl<int64_t> {
|
||||
public:
|
||||
using ConstantSymNodeImpl<int64_t>::ConstantSymNodeImpl;
|
||||
std::optional<int64_t> constant_int() override {
|
||||
return std::nullopt;
|
||||
}
|
||||
std::optional<int64_t> maybe_as_int() override {
|
||||
return std::nullopt;
|
||||
}
|
||||
// Needs to be implemented for arithmetic to actually
|
||||
// work. NestedIntSymNodeImpl does this, for example.
|
||||
c10::SymNode wrap_int(int64_t num) override {
|
||||
return SymNode(
|
||||
c10::make_intrusive<ConstantIntPretendingToBeSymbolicSymNodeImpl>(num));
|
||||
}
|
||||
|
||||
c10::SymNode wrap_bool(bool b) override {
|
||||
return SymNode(c10::make_intrusive<ConstantSymNodeImpl<bool>>(b));
|
||||
}
|
||||
|
||||
SymNode add(const SymNode& other) override {
|
||||
return wrap_int(int_() + other->int_());
|
||||
}
|
||||
|
||||
SymNode sub(const SymNode& other) override {
|
||||
return wrap_int(int_() - other->int_());
|
||||
}
|
||||
|
||||
SymNode mul(const SymNode& other) override {
|
||||
return wrap_int(int_() * other->int_());
|
||||
}
|
||||
|
||||
SymNode floordiv(const SymNode& other) override {
|
||||
return wrap_int(int_() / other->int_());
|
||||
}
|
||||
|
||||
SymNode sym_min(const SymNode& other) override {
|
||||
return wrap_int(std::min(int_(), other->int_()));
|
||||
}
|
||||
|
||||
SymNode sym_max(const SymNode& other) override {
|
||||
return wrap_int(std::max(int_(), other->int_()));
|
||||
}
|
||||
|
||||
SymNode mod(const SymNode& other) override {
|
||||
return wrap_int(int_() % other->int_());
|
||||
}
|
||||
|
||||
SymNode eq(const SymNode& other) override {
|
||||
return wrap_bool(int_() == other->int_());
|
||||
}
|
||||
|
||||
SymNode ne(const SymNode& other) override {
|
||||
return wrap_bool(int_() != other->int_());
|
||||
}
|
||||
|
||||
SymNode lt(const SymNode& other) override {
|
||||
return wrap_bool(int_() < other->int_());
|
||||
}
|
||||
|
||||
SymNode le(const SymNode& other) override {
|
||||
return wrap_bool(int_() <= other->int_());
|
||||
}
|
||||
|
||||
SymNode gt(const SymNode& other) override {
|
||||
return wrap_bool(int_() > other->int_());
|
||||
}
|
||||
|
||||
SymNode ge(const SymNode& other) override {
|
||||
return wrap_bool(int_() >= other->int_());
|
||||
}
|
||||
};
|
||||
|
||||
SymInt create_symbolic_symint(int64_t value) {
|
||||
return SymInt(
|
||||
SymNode(c10::make_intrusive<ConstantIntPretendingToBeSymbolicSymNodeImpl>(
|
||||
value)));
|
||||
}
|
||||
|
||||
auto unwrap(const SymInt& x) {
|
||||
return x.guard_int(__FILE__, __LINE__);
|
||||
}
|
||||
|
||||
auto unwrap(bool b) {
|
||||
return b;
|
||||
}
|
||||
|
||||
template <template <typename> class Op>
|
||||
void test_operator() {
|
||||
for (const auto& arg1 : {SymInt(42), create_symbolic_symint(42)}) {
|
||||
for (const auto& arg2 : {SymInt(27), create_symbolic_symint(27)}) {
|
||||
EXPECT_EQ(unwrap(Op<SymInt>()(arg1, arg2)), Op<int64_t>()(42, 27));
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TEST(SymIntTest, BinaryPlus) {
|
||||
test_operator<std::plus>();
|
||||
}
|
||||
|
||||
TEST(SymIntTest, BinaryMinus) {
|
||||
test_operator<std::minus>();
|
||||
}
|
||||
|
||||
TEST(SymIntTest, BinaryMultiplies) {
|
||||
test_operator<std::multiplies>();
|
||||
}
|
||||
|
||||
TEST(SymIntTest, BinaryDivides) {
|
||||
test_operator<std::divides>();
|
||||
}
|
||||
|
||||
TEST(SymIntTest, BinaryModulus) {
|
||||
test_operator<std::modulus>();
|
||||
}
|
||||
|
||||
TEST(SymIntTest, BinaryComparisonOperators) {
|
||||
test_operator<std::equal_to>();
|
||||
test_operator<std::not_equal_to>();
|
||||
test_operator<std::less>();
|
||||
test_operator<std::less_equal>();
|
||||
test_operator<std::greater>();
|
||||
test_operator<std::greater_equal>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct MinWrapper {
|
||||
auto operator()(T lhs, T rhs) const {
|
||||
return std::min(lhs, rhs);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MinWrapper<SymInt> {
|
||||
auto operator()(const SymInt& lhs, const SymInt& rhs) const {
|
||||
return lhs.min(rhs);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct MaxWrapper {
|
||||
auto operator()(T lhs, T rhs) const {
|
||||
return std::max(lhs, rhs);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MaxWrapper<SymInt> {
|
||||
auto operator()(const SymInt& lhs, const SymInt& rhs) const {
|
||||
return lhs.max(rhs);
|
||||
}
|
||||
};
|
||||
|
||||
TEST(SymIntTest, MinMax) {
|
||||
test_operator<MinWrapper>();
|
||||
test_operator<MaxWrapper>();
|
||||
}
|
||||
#endif
|
||||
|
@ -104,10 +104,17 @@ if TEST_WITH_ROCM:
|
||||
test_failures["test_unbacked_reduction"] = TestFailure(("cpu"), is_skip=True)
|
||||
|
||||
|
||||
if os.getenv("BUILD_ENVIRONMENT", "").endswith("-debug"):
|
||||
if any(os.getenv("BUILD_ENVIRONMENT", "").endswith(x) for x in ("-debug", "-asan")):
|
||||
# Fails with TORCH_INTERNAL_ASSERT(!is_heap_allocated()), see https://github.com/pytorch/pytorch/issues/130073
|
||||
test_failures["test_resize_as_dynamic_shapes"] = TestFailure(("cpu", "cuda"))
|
||||
test_failures["test_resize_dynamic_shapes"] = TestFailure(("cpu", "cuda"))
|
||||
# After https://github.com/pytorch/pytorch/pull/161586, starts failing UBSAN so we can't even xfail.
|
||||
# Root cause seems to be SymInt issues in StorageImpl, see
|
||||
# https://github.com/pytorch/pytorch/pull/161586#issuecomment-3246530671
|
||||
test_failures["test_resize_as_dynamic_shapes"] = TestFailure(
|
||||
("cpu", "cuda"), is_skip=True
|
||||
)
|
||||
test_failures["test_resize_dynamic_shapes"] = TestFailure(
|
||||
("cpu", "cuda"), is_skip=True
|
||||
)
|
||||
|
||||
|
||||
def make_dynamic_cls(cls, xfail_prop="_expected_failure_dynamic"):
|
||||
|
Reference in New Issue
Block a user