diff --git a/.github/ci_commit_pins/xla.txt b/.github/ci_commit_pins/xla.txt index a3de2aba624e..ec894b7a5f4b 100644 --- a/.github/ci_commit_pins/xla.txt +++ b/.github/ci_commit_pins/xla.txt @@ -1 +1 @@ -095ee628212f0235ad0d6908bdd514123639fc86 +1e9b8bdc75114ac6c16305c970be37a1cd2fdb1c diff --git a/.lintrunner.toml b/.lintrunner.toml index a48d411ea9a8..62b13822e4ad 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -439,7 +439,7 @@ command = [ """--error-description=\ This line has an isinstance call that directly refers to \ int or float. This is error-prone because you may also \ - have wanted to allow SymIntNode or SymFloatNode in your test. \ + have wanted to allow SymInt or SymFloat in your test. \ To suppress this lint, use an appropriate type alias defined \ in torch._prims_common; use IntLike/FloatLike when you would accept \ both regular and symbolic numbers, Dim for ints representing \ diff --git a/aten/src/ATen/FunctionalStorageImpl.cpp b/aten/src/ATen/FunctionalStorageImpl.cpp index e50ffbdcf511..f42c53538990 100644 --- a/aten/src/ATen/FunctionalStorageImpl.cpp +++ b/aten/src/ATen/FunctionalStorageImpl.cpp @@ -95,7 +95,7 @@ c10::SymInt get_nbytes(const Tensor& value) { if (value.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) { // Today, the two implementations of SymInt are in Python (proxy tensor), // and lazy tensor (LTC/XLA). - // LTC hasn't implemented SymInt support yet though (torch::lazy::SymIntNodeImpl). + // LTC hasn't implemented SymInt support yet though // Once it does, we should remove this check. if (value.key_set().has(c10::DispatchKey::Python)) { return value.storage().sym_nbytes(); diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h index 122afcba4d84..e9a5ea9ec6a2 100644 --- a/aten/src/ATen/core/ivalue.h +++ b/aten/src/ATen/core/ivalue.h @@ -562,7 +562,7 @@ public: IValue(c10::SymInt i) { if (i.is_symbolic()) { tag = Tag::SymInt; - payload.u.as_intrusive_ptr = i.toSymIntNodeImpl().release(); + payload.u.as_intrusive_ptr = i.toSymNodeImpl().release(); } else { tag = Tag::Int; payload.u.as_int = i.as_int_unchecked(); @@ -578,7 +578,7 @@ public: IValue(c10::SymFloat i) { if (i.is_symbolic()) { tag = Tag::SymFloat; - payload.u.as_intrusive_ptr = i.toSymFloatNodeImpl().release(); + payload.u.as_intrusive_ptr = i.toSymNodeImpl().release(); } else { tag = Tag::Double; payload.u.as_double = i.as_float_unchecked(); @@ -812,10 +812,10 @@ public: // for both SymFloat and double if (s.isSymInt()) { tag = Tag::SymInt; - payload.u.as_intrusive_ptr = s.toSymInt().toSymIntNodeImpl().release(); + payload.u.as_intrusive_ptr = s.toSymInt().toSymNodeImpl().release(); } else if (s.isSymFloat()) { tag = Tag::SymFloat; - payload.u.as_intrusive_ptr = s.toSymFloat().toSymFloatNodeImpl().release(); + payload.u.as_intrusive_ptr = s.toSymFloat().toSymNodeImpl().release(); } else if (s.isFloatingPoint()) { tag = Tag::Double; payload.u.as_double = s.toDouble(); diff --git a/aten/src/ATen/core/ivalue_inl.h b/aten/src/ATen/core/ivalue_inl.h index 1c3453abb4c8..bea795c8d81e 100644 --- a/aten/src/ATen/core/ivalue_inl.h +++ b/aten/src/ATen/core/ivalue_inl.h @@ -219,7 +219,7 @@ inline at::Generator IValue::toGenerator() const& { inline c10::SymInt IValue::toSymInt() const { AT_ASSERT(isSymInt() || isInt(), "Expected SymInt or int but got ", tagKind()); if (isSymInt()) { - return c10::SymInt::toSymInt(toIntrusivePtr()); + return c10::SymInt(toIntrusivePtr()); } else { return c10::SymInt(payload.u.as_int); } @@ -228,7 +228,7 @@ inline c10::SymInt IValue::toSymInt() const { inline c10::SymFloat IValue::toSymFloat() const { AT_ASSERT(isSymFloat() || isDouble(), "Expected SymFloat or double but got ", tagKind()); if (isSymFloat()) { - return c10::SymFloat::toSymFloat(toIntrusivePtr()); + return c10::SymFloat(toIntrusivePtr()); } else { return c10::SymFloat(payload.u.as_double); } diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index e554bd586272..0a8f5e14d9a5 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -1310,7 +1310,6 @@ struct TORCH_API SymIntType : public Type { return "SymInt"; } std::string annotation_str_impl(TypePrinter printer = nullptr) const override { - // TODO: will become a Union[SymIntNodeImpl|int] in the near future return "int"; } static const TypeKind Kind = TypeKind::SymIntType; diff --git a/aten/src/ATen/test/scalar_test.cpp b/aten/src/ATen/test/scalar_test.cpp index bd9e84bc2355..b6762e173945 100644 --- a/aten/src/ATen/test/scalar_test.cpp +++ b/aten/src/ATen/test/scalar_test.cpp @@ -194,34 +194,3 @@ TEST(TestScalar, TestFormatting) { ASSERT_EQ("(2,3.1)", format(Scalar(c10::complex(2.0, 3.1)))); ASSERT_EQ("4", format(Scalar(Scalar(4).toSymInt()))); } - -TEST(TestSymInt, Basic) { - Scalar foo; - auto a_impl = c10::make_intrusive(); - foo = Scalar(a_impl->toSymInt()); - ASSERT_EQ(a_impl.use_count(), 2); - Scalar bar{foo}; - ASSERT_EQ(a_impl.use_count(), 3); - auto baz = bar; - ASSERT_EQ(a_impl.use_count(), 4); - auto foo2 = std::move(bar); - ASSERT_EQ(a_impl.use_count(), 4); - ASSERT_TRUE(foo2.isSymInt()); - // NOLINTNEXTLINE(bugprone-use-after-move,clang-analyzer-cplusplus.Move) - ASSERT_TRUE(bar.isIntegral(false)); - foo2 = SymInt(4); - ASSERT_FALSE(foo2.isSymInt()); - ASSERT_EQ(foo2.toSymInt().expect_int(), 4); - // NOLINTNEXTLINE(clang-diagnostic-self-assign-overloaded) - foo2 = foo2; - ASSERT_FALSE(foo2.isSymInt()); - ASSERT_EQ(foo2.toSymInt().expect_int(), 4); - - ASSERT_EQ(a_impl.use_count(), 3); - - ASSERT_THROW(foo.to(), c10::Error); - - Scalar int_s = 3; - TORCH_CHECK(int_s.toSymInt().expect_int(), 3); - -} diff --git a/build_variables.bzl b/build_variables.bzl index 017ed9aef541..12ad9730123f 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -958,6 +958,7 @@ libtorch_python_core_sources = [ "torch/csrc/utils/object_ptr.cpp", "torch/csrc/utils/python_arg_parser.cpp", "torch/csrc/utils/python_dispatch.cpp", + "torch/csrc/utils/python_symnode.cpp", "torch/csrc/utils/structseq.cpp", "torch/csrc/utils/tensor_apply.cpp", "torch/csrc/utils/tensor_dtypes.cpp", diff --git a/c10/core/Scalar.h b/c10/core/Scalar.h index c0d89315b65d..0c124177e38f 100644 --- a/c10/core/Scalar.h +++ b/c10/core/Scalar.h @@ -92,8 +92,8 @@ class C10_API Scalar { SymInt toSymInt() const { if (Tag::HAS_si == tag) { - return c10::SymInt::toSymInt(intrusive_ptr::reclaim_copy( - static_cast(v.p))); + return c10::SymInt(intrusive_ptr::reclaim_copy( + static_cast(v.p))); } else { return toLong(); } @@ -101,9 +101,8 @@ class C10_API Scalar { SymFloat toSymFloat() const { if (Tag::HAS_sd == tag) { - return c10::SymFloat::toSymFloat( - intrusive_ptr::reclaim_copy( - static_cast(v.p))); + return c10::SymFloat(intrusive_ptr::reclaim_copy( + static_cast(v.p))); } else { return toDouble(); } diff --git a/c10/core/SymFloat.cpp b/c10/core/SymFloat.cpp index 0ba980a9727e..3c1fea2ee350 100644 --- a/c10/core/SymFloat.cpp +++ b/c10/core/SymFloat.cpp @@ -1,32 +1,27 @@ #include -#include +#include #include namespace c10 { -SymFloatNode SymFloat::toSymFloatNodeImpl() const { +SymNode SymFloat::toSymNodeImpl() const { TORCH_CHECK(is_symbolic()); - return SymFloatNode::reclaim_copy(toSymFloatNodeImplUnowned()); + return SymNode::reclaim_copy(toSymNodeImplUnowned()); } -static std::array normalize_symfloats( - SymFloat a_, - SymFloat b_) { - SymFloatNode a, b; +static std::array normalize_symfloats(SymFloat a_, SymFloat b_) { + SymNode a, b; if (a_.is_symbolic()) - a = a_.toSymFloatNodeImpl(); + a = a_.toSymNodeImpl(); if (b_.is_symbolic()) - b = b_.toSymFloatNodeImpl(); + b = b_.toSymNodeImpl(); - SymFloatNodeImpl* common = a ? a.get() : b.get(); - // TODO: technically we need to check that the classes match + SymNodeImpl* common = a ? a.get() : b.get(); if (!a) { - a = common->wrap(a_.as_float_unchecked()); - a_.toSymFloat(a); // + a = common->wrap_float(a_.as_float_unchecked()); } if (!b) { - b = common->wrap(b_.as_float_unchecked()); - b_.toSymFloat(b); + b = common->wrap_float(b_.as_float_unchecked()); } return {a, b}; } @@ -36,7 +31,7 @@ SymFloat SymFloat::operator+(SymFloat sci) const { return SymFloat(data_ + sci.data_); } auto res = normalize_symfloats(*this, sci); - return SymFloat::toSymFloat(res[0]->add(res[1])); + return SymFloat(res[0]->add(res[1])); } SymFloat SymFloat::operator-(SymFloat sci) const { @@ -44,7 +39,7 @@ SymFloat SymFloat::operator-(SymFloat sci) const { return SymFloat(data_ - sci.data_); } auto res = normalize_symfloats(*this, sci); - return SymFloat::toSymFloat(res[0]->sub(res[1])); + return SymFloat(res[0]->sub(res[1])); } SymFloat SymFloat::operator*(SymFloat sci) const { @@ -52,7 +47,7 @@ SymFloat SymFloat::operator*(SymFloat sci) const { return SymFloat(data_ * sci.data_); } auto res = normalize_symfloats(*this, sci); - return SymFloat::toSymFloat(res[0]->mul(res[1])); + return SymFloat(res[0]->mul(res[1])); } SymFloat SymFloat::operator/(SymFloat sci) const { @@ -60,16 +55,12 @@ SymFloat SymFloat::operator/(SymFloat sci) const { return SymFloat(data_ / sci.data_); } auto res = normalize_symfloats(*this, sci); - return SymFloat::toSymFloat(res[0]->truediv(res[1])); -} - -c10::SymFloat SymFloat::toSymFloat(SymFloatNode sin_sp) { - return c10::SymFloat(std::move(sin_sp)); + return SymFloat(res[0]->truediv(res[1])); } std::ostream& operator<<(std::ostream& os, SymFloat s) { if (s.is_symbolic()) { - os << s.toSymFloatNodeImpl()->str(); + os << s.toSymNodeImpl()->str(); } else { os << s.as_float_unchecked(); } diff --git a/c10/core/SymFloat.h b/c10/core/SymFloat.h index 92abb81ea2a2..b787c020fd75 100644 --- a/c10/core/SymFloat.h +++ b/c10/core/SymFloat.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include #include @@ -14,20 +14,21 @@ namespace c10 { class C10_API SymFloat { public: /*implicit*/ SymFloat(double d) : data_(d){}; - SymFloat(SymFloatNode ptr) - : data_(std::numeric_limits::quiet_NaN()), ptr_(std::move(ptr)){}; + SymFloat(SymNode ptr) + : data_(std::numeric_limits::quiet_NaN()), ptr_(std::move(ptr)) { + TORCH_CHECK(ptr_->is_float()); + }; SymFloat() : data_(0.0) {} - SymFloatNodeImpl* toSymFloatNodeImplUnowned() const { + SymNodeImpl* toSymNodeImplUnowned() const { return ptr_.get(); } - SymFloatNodeImpl* release() && { + SymNodeImpl* release() && { return std::move(ptr_).release(); } - SymFloatNode toSymFloatNodeImpl() const; - static c10::SymFloat toSymFloat(SymFloatNode sin); + SymNode toSymNodeImpl() const; double expect_float() const { TORCH_CHECK(!is_symbolic()); @@ -53,7 +54,7 @@ class C10_API SymFloat { private: // TODO: optimize to union double data_; - SymFloatNode ptr_; + SymNode ptr_; }; C10_API std::ostream& operator<<(std::ostream& os, SymFloat s); diff --git a/c10/core/SymFloatNodeImpl.cpp b/c10/core/SymFloatNodeImpl.cpp deleted file mode 100644 index 714ee095d84e..000000000000 --- a/c10/core/SymFloatNodeImpl.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include -#include -#include - -namespace c10 { - -c10::SymFloat SymFloatNodeImpl::toSymFloat() { - auto sit_sp = SymFloatNode::reclaim_copy(this); - return SymFloat::toSymFloat(sit_sp); -} - -c10::SymIntNode SymFloatNodeImpl::ceil() { - TORCH_CHECK(false, "NYI"); -} - -c10::SymIntNode SymFloatNodeImpl::floor() { - TORCH_CHECK(false, "NYI"); -} - -} // namespace c10 diff --git a/c10/core/SymFloatNodeImpl.h b/c10/core/SymFloatNodeImpl.h deleted file mode 100644 index 0ab9d952b5bb..000000000000 --- a/c10/core/SymFloatNodeImpl.h +++ /dev/null @@ -1,76 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include - -namespace c10 { - -class SymIntNodeImpl; -using SymIntNode = c10::intrusive_ptr; - -class SymFloat; -class SymFloatNodeImpl; -using SymFloatNode = c10::intrusive_ptr; - -class C10_API SymFloatNodeImpl : public c10::intrusive_ptr_target { - public: - c10::SymFloat toSymFloat(); - virtual ~SymFloatNodeImpl(){}; - - template - c10::intrusive_ptr dyn_cast() const { - return c10::intrusive_ptr::reclaim_copy(dynamic_cast(this)); - } - - virtual SymFloatNode wrap(double num) { - TORCH_CHECK(false, "NYI"); - }; - virtual SymFloatNode add(const SymFloatNode& other) { - TORCH_CHECK(false, "NYI"); - } - virtual SymFloatNode sub(const SymFloatNode& other) { - TORCH_CHECK(false, "NYI"); - } - virtual SymFloatNode mul(const SymFloatNode& other) { - TORCH_CHECK(false, "NYI"); - } - virtual SymFloatNode truediv(const SymFloatNode& other) { - TORCH_CHECK(false, "NYI"); - } - virtual SymFloatNode pow(const SymFloatNode& other) { - TORCH_CHECK(false, "NYI"); - } - virtual SymFloatNode eq(const SymFloatNode& other) { - TORCH_CHECK(false, "NYI"); - }; - virtual SymFloatNode ne(const SymFloatNode& other) { - TORCH_CHECK(false, "NYI"); - }; - virtual SymFloatNode gt(const SymFloatNode& other) { - TORCH_CHECK(false, "NYI"); - }; - virtual SymFloatNode lt(const SymFloatNode& other) { - TORCH_CHECK(false, "NYI"); - }; - virtual SymFloatNode le(const SymFloatNode& other) { - TORCH_CHECK(false, "NYI"); - }; - virtual SymFloatNode ge(const SymFloatNode& other) { - TORCH_CHECK(false, "NYI"); - }; - virtual SymIntNode ceil(); - virtual SymIntNode floor(); - virtual std::string str() { - TORCH_CHECK(false, "NYI"); - }; - std::ostream& operator<<(std::ostream& os) { - os << str(); - return os; - }; -}; - -} // namespace c10 diff --git a/c10/core/SymInt.cpp b/c10/core/SymInt.cpp index 5ef576b3af1b..b32157e4a94e 100644 --- a/c10/core/SymInt.cpp +++ b/c10/core/SymInt.cpp @@ -1,47 +1,46 @@ #include #include -#include +#include #include namespace c10 { -static std::array normalize_symints(SymInt a_, SymInt b_) { - SymIntNode a, b; +static std::array normalize_symints(SymInt a_, SymInt b_) { + SymNode a, b; if (a_.is_symbolic()) - a = a_.toSymIntNodeImpl(); + a = a_.toSymNodeImpl(); if (b_.is_symbolic()) - b = b_.toSymIntNodeImpl(); + b = b_.toSymNodeImpl(); - SymIntNodeImpl* common = a ? a.get() : b.get(); + SymNodeImpl* common = a ? a.get() : b.get(); // TODO: technically we need to check that the classes match if (!a) { - a = common->wrap(a_.as_int_unchecked()); - a_.toSymInt(a); // + a = common->wrap_int(a_.as_int_unchecked()); } if (!b) { - b = common->wrap(b_.as_int_unchecked()); - b_.toSymInt(b); + b = common->wrap_int(b_.as_int_unchecked()); } return {a, b}; } -SymIntNode SymInt::toSymIntNodeImpl() const { +SymNode SymInt::toSymNodeImpl() const { TORCH_CHECK(is_symbolic()); - return SymIntNode::reclaim_copy(toSymIntNodeImplUnowned()); + return SymNode::reclaim_copy(toSymNodeImplUnowned()); } -c10::SymInt SymInt::toSymInt(SymIntNode sin_sp) { +SymInt::SymInt(SymNode sin_sp) { + TORCH_CHECK(sin_sp->is_int()); auto ptr = static_cast( reinterpret_cast(static_cast(sin_sp.release()))); auto rep = (ptr & ~MASK) | IS_SYM; - return c10::SymInt(UNCHECKED, static_cast(rep)); + data_ = static_cast(rep); } int64_t SymInt::guard_int(const char* file, int64_t line) const { if (!is_symbolic()) { return data_; } - SymIntNode a = toSymIntNodeImpl(); + SymNode a = toSymNodeImpl(); return a->guard_int(file, line); } @@ -49,7 +48,7 @@ SymInt::operator SymFloat() const { if (!is_symbolic()) { return SymFloat(double(data_)); } - return SymFloat::toSymFloat(toSymIntNodeImpl()->sym_float()); + return SymFloat(toSymNodeImpl()->sym_float()); } SymInt SymInt::operator+(SymInt sci) const { @@ -57,7 +56,7 @@ SymInt SymInt::operator+(SymInt sci) const { return SymInt(data_ + sci.data_); } auto res = normalize_symints(*this, sci); - return SymInt::toSymInt(res[0]->add(res[1])); + return SymInt(res[0]->add(res[1])); } SymInt SymInt::operator-(SymInt sci) const { @@ -65,7 +64,7 @@ SymInt SymInt::operator-(SymInt sci) const { return SymInt(data_ - sci.data_); } auto res = normalize_symints(*this, sci); - return SymInt::toSymInt(res[0]->sub(res[1])); + return SymInt(res[0]->sub(res[1])); } SymInt SymInt::operator*(SymInt sci) const { @@ -73,7 +72,7 @@ SymInt SymInt::operator*(SymInt sci) const { return SymInt(data_ * sci.data_); } auto res = normalize_symints(*this, sci); - return SymInt::toSymInt(res[0]->mul(res[1])); + return SymInt(res[0]->mul(res[1])); } SymInt SymInt::operator/(SymInt sci) const { @@ -81,7 +80,7 @@ SymInt SymInt::operator/(SymInt sci) const { return SymInt(data_ / sci.data_); } auto res = normalize_symints(*this, sci); - return SymInt::toSymInt(res[0]->floordiv(res[1])); + return SymInt(res[0]->floordiv(res[1])); } SymInt SymInt::operator%(SymInt sci) const { @@ -89,7 +88,7 @@ SymInt SymInt::operator%(SymInt sci) const { return SymInt(data_ % sci.data_); } auto res = normalize_symints(*this, sci); - return SymInt::toSymInt(res[0]->mod(res[1])); + return SymInt(res[0]->mod(res[1])); } bool SymInt::operator==(SymInt sci) const { @@ -141,14 +140,14 @@ SymInt SymInt::min(SymInt sci) const { return std::min(data_, sci.data_); } auto res = normalize_symints(*this, sci); - return SymInt::toSymInt(res[0]->min(res[1])); + return SymInt(res[0]->min(res[1])); } SymInt SymInt::max(SymInt sci) const { if (!is_symbolic() && !sci.is_symbolic()) { return std::max(data_, sci.data_); } auto res = normalize_symints(*this, sci); - return SymInt::toSymInt(res[0]->max(res[1])); + return SymInt(res[0]->max(res[1])); } void SymInt::operator*=(SymInt sci) { @@ -193,7 +192,7 @@ SymInt SymInt::operator*(int64_t sci) const { std::ostream& operator<<(std::ostream& os, SymInt s) { if (s.is_symbolic()) { - os << s.toSymIntNodeImpl()->str(); + os << s.toSymNodeImpl()->str(); } else { os << s.as_int_unchecked(); } @@ -202,7 +201,7 @@ std::ostream& operator<<(std::ostream& os, SymInt s) { SymInt operator-(SymInt s) { if (s.is_symbolic()) { - return SymInt::toSymInt(s.toSymIntNodeImpl()->neg()); + return SymInt(s.toSymNodeImpl()->neg()); } else { return SymInt(-s.as_int_unchecked()); } diff --git a/c10/core/SymInt.h b/c10/core/SymInt.h index 6934a607ccbf..a10775196d86 100644 --- a/c10/core/SymInt.h +++ b/c10/core/SymInt.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include #include @@ -12,24 +12,19 @@ namespace c10 { class SymFloat; -// `SymInt` is a C++ wrapper class around int64_t data_ which and is used to -// represent concrete dimension values. +// SymInt represents either a regular int64_t, or a symbolic integer +// (represented in a type erased way as SymNode). The intention is for SymInt +// to represent symbolic sizes that arise when doing shape computation in +// operator kernels. This allows for tracing through programs without baking in +// concrete sizes into kernel calls. // -// `SymInt` is also a data type in Pytorch that can be used in function schemas -// to enable tracing. +// SymInt has an API equivalent to int64_t. In particular, it is a value type. +// Internally, SymInt is represented in a clever packed way, so that it only +// occupies one word of space; but morally, it is a union between an int64_t +// and an intrusive pointer to SymNodeImpl. // -// `SymInt` is introduced to enable tracing arithmetic -// operations on symbolic integers (e.g. sizes). Tracing symbolic sizes will -// allow LTC and AOTAutograd representing dynamic shapes in expression graphs -// faithfully without baking in concrete dimension values. -// -// To trace the operations, SymInt will overload arithmetic operators (e.g. +, -// -, *) and will provide overloads taking SymInt for commonly used math -// functions. -// -// SymInt will be extenteded to represent a union structure Union[int64_t, -// SymIntNodeImpl*] which will be implemented as a single packed int64_t field -// named data_. +// Invariant: the referenced SymNodeImpl is guaranteed to be a SymNode where +// is_int() returns true class C10_API SymInt { public: @@ -44,6 +39,7 @@ class C10_API SymInt { TORCH_CHECK(!is_symbolic()); }; SymInt() : data_(0) {} + SymInt(SymNode n); // unchecked c-tor accepting raw `data_` // One appropriate use for this is when you are constructing a symint @@ -55,7 +51,7 @@ class C10_API SymInt { // temporary and then use the move constructor/assignment SymInt(const SymInt& s) : data_(0) { if (s.is_symbolic()) { - *this = SymInt::toSymInt(s.toSymIntNodeImpl()); + *this = SymInt(s.toSymNodeImpl()); } else { data_ = s.data_; } @@ -67,7 +63,7 @@ class C10_API SymInt { SymInt& operator=(const SymInt& s) { if (this != &s) { if (s.is_symbolic()) { - *this = SymInt::toSymInt(s.toSymIntNodeImpl()); + *this = SymInt(s.toSymNodeImpl()); } else { data_ = s.data_; } @@ -76,7 +72,7 @@ class C10_API SymInt { } SymInt& operator=(SymInt&& s) { if (this != &s) { - release_(); // release the current SymIntNode if any + release_(); // release the current SymNode if any data_ = s.data_; if (s.is_symbolic()) s.data_ = 0; @@ -86,31 +82,31 @@ class C10_API SymInt { SymInt clone() const { if (is_symbolic()) { - return toSymIntNodeImplUnowned()->clone()->toSymInt(); + return SymInt(toSymNodeImplUnowned()->clone()); } return *this; } - SymIntNodeImpl* toSymIntNodeImplUnowned() const { + SymNodeImpl* toSymNodeImplUnowned() const { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(is_symbolic()); uint64_t unextended_bits = static_cast(data_) & ~MASK; uint64_t sign_bit_mask = 1ULL << (62 - 1); // https://stackoverflow.com/questions/42534749/signed-extension-from-24-bit-to-32-bit-in-c uint64_t extended_bits = (unextended_bits ^ sign_bit_mask) - sign_bit_mask; - return static_cast( + return static_cast( reinterpret_cast(static_cast(extended_bits))); } void release_() { if (is_symbolic()) { - SymIntNode::reclaim(toSymIntNodeImplUnowned()); // steal + SymNode::reclaim(toSymNodeImplUnowned()); // steal } } - SymIntNodeImpl* release() && { + SymNodeImpl* release() && { #ifndef C10_MOBILE TORCH_INTERNAL_ASSERT(is_symbolic()); - auto* r = toSymIntNodeImplUnowned(); + auto* r = toSymNodeImplUnowned(); data_ = 0; // transfer ownership return r; #else @@ -118,8 +114,7 @@ class C10_API SymInt { #endif } - SymIntNode toSymIntNodeImpl() const; - static c10::SymInt toSymInt(SymIntNode sin); + SymNode toSymNodeImpl() const; ~SymInt() { release_(); diff --git a/c10/core/SymIntNodeImpl.cpp b/c10/core/SymIntNodeImpl.cpp deleted file mode 100644 index 483110a90fa6..000000000000 --- a/c10/core/SymIntNodeImpl.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include -#include - -namespace c10 { - -c10::SymInt SymIntNodeImpl::toSymInt() { - auto sit_sp = SymIntNode::reclaim_copy(this); - return SymInt::toSymInt(sit_sp); -} - -} // namespace c10 diff --git a/c10/core/SymNodeImpl.cpp b/c10/core/SymNodeImpl.cpp new file mode 100644 index 000000000000..80999ba50f1e --- /dev/null +++ b/c10/core/SymNodeImpl.cpp @@ -0,0 +1,3 @@ +#include + +namespace c10 {} // namespace c10 diff --git a/c10/core/SymIntNodeImpl.h b/c10/core/SymNodeImpl.h similarity index 50% rename from c10/core/SymIntNodeImpl.h rename to c10/core/SymNodeImpl.h index 0b9d4c557928..d2f3aafaad8b 100644 --- a/c10/core/SymIntNodeImpl.h +++ b/c10/core/SymNodeImpl.h @@ -1,6 +1,5 @@ #pragma once -#include #include #include #include @@ -10,13 +9,12 @@ namespace c10 { -class SymInt; -class SymIntNodeImpl; +class SymNodeImpl; +using SymNode = c10::intrusive_ptr; -class C10_API SymIntNodeImpl : public c10::intrusive_ptr_target { +class C10_API SymNodeImpl : public c10::intrusive_ptr_target { public: - c10::SymInt toSymInt(); - virtual ~SymIntNodeImpl(){}; + virtual ~SymNodeImpl(){}; template c10::intrusive_ptr dyn_cast() const { @@ -24,66 +22,87 @@ class C10_API SymIntNodeImpl : public c10::intrusive_ptr_target { } // these could be pure virtual when we implement LTC versions - virtual SymIntNode add(const SymIntNode& other) { + virtual bool is_int() { TORCH_CHECK(false, "NYI"); }; - virtual SymIntNode sub(const SymIntNode& other) { + virtual bool is_float() { TORCH_CHECK(false, "NYI"); }; - virtual SymIntNode mul(const SymIntNode& other) { + virtual SymNode add(const SymNode& other) { TORCH_CHECK(false, "NYI"); }; - virtual SymFloatNode truediv(const SymIntNode& other) { + virtual SymNode sub(const SymNode& other) { TORCH_CHECK(false, "NYI"); }; - virtual SymIntNode floordiv(const SymIntNode& other) { + virtual SymNode mul(const SymNode& other) { TORCH_CHECK(false, "NYI"); }; - virtual SymIntNode mod(const SymIntNode& other) { + virtual SymNode truediv(const SymNode& other) { TORCH_CHECK(false, "NYI"); }; - virtual SymIntNode eq(const SymIntNode& other) { + virtual SymNode pow(const SymNode& other) { TORCH_CHECK(false, "NYI"); }; - virtual SymIntNode ne(const SymIntNode& other) { + virtual SymNode floordiv(const SymNode& other) { TORCH_CHECK(false, "NYI"); }; - virtual SymIntNode gt(const SymIntNode& other) { + virtual SymNode mod(const SymNode& other) { TORCH_CHECK(false, "NYI"); }; - virtual SymIntNode lt(const SymIntNode& other) { + virtual SymNode eq(const SymNode& other) { TORCH_CHECK(false, "NYI"); }; - virtual SymIntNode le(const SymIntNode& other) { + virtual SymNode ne(const SymNode& other) { TORCH_CHECK(false, "NYI"); }; - virtual SymIntNode ge(const SymIntNode& other) { + virtual SymNode gt(const SymNode& other) { TORCH_CHECK(false, "NYI"); }; - virtual SymIntNode ceil() { + virtual SymNode lt(const SymNode& other) { TORCH_CHECK(false, "NYI"); }; - virtual SymIntNode neg() { + virtual SymNode le(const SymNode& other) { TORCH_CHECK(false, "NYI"); }; - virtual SymIntNode min(const SymIntNode& other) { + virtual SymNode ge(const SymNode& other) { TORCH_CHECK(false, "NYI"); }; - virtual SymIntNode max(const SymIntNode& other) { + virtual SymNode ceil() { TORCH_CHECK(false, "NYI"); }; - virtual SymIntNode clone() { + virtual SymNode floor() { TORCH_CHECK(false, "NYI"); }; - virtual SymFloatNode sym_float() { + virtual SymNode neg() { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode min(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode max(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode clone() { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode sym_int() { TORCH_CHECK(false, "NYI"); } - virtual SymIntNode wrap(int64_t num) { + virtual SymNode sym_float() { + TORCH_CHECK(false, "NYI"); + } + virtual SymNode wrap_int(int64_t num) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode wrap_float(double num) { TORCH_CHECK(false, "NYI"); }; virtual int64_t guard_int(const char* file, int64_t line) { TORCH_CHECK(false, "NYI"); }; + virtual double guard_float(const char* file, int64_t line) { + TORCH_CHECK(false, "NYI"); + }; virtual int64_t int_() { TORCH_CHECK(false, "NYI"); }; diff --git a/c10/test/core/SymInt_test.cpp b/c10/test/core/SymInt_test.cpp index a57e7c706486..d889d72b5afb 100644 --- a/c10/test/core/SymInt_test.cpp +++ b/c10/test/core/SymInt_test.cpp @@ -1,7 +1,7 @@ #include #include -#include +#include using namespace c10; #ifndef C10_MOBILE @@ -20,12 +20,6 @@ TEST(SymIntTest, ConcreteInts) { check(-4611686018427387904LL); } -TEST(SymIntTest, AddNode) { - auto n = c10::make_intrusive(); - auto i = n->toSymInt(); - EXPECT_TRUE(i.is_symbolic()); -} - TEST(SymIntTest, CheckRange) { EXPECT_FALSE(SymInt::check_range(INT64_MIN)); } diff --git a/docs/source/conf.py b/docs/source/conf.py index 8c0eac82cf99..807f486ac0d6 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -335,8 +335,8 @@ coverage_ignore_classes = [ "Quantize", # torch.utils.backcompat "Warning", - "SymIntNode", - "SymFloatNode", + "SymInt", + "SymFloat", ] # The suffix(es) of source filenames. diff --git a/functorch/_src/aot_autograd.py b/functorch/_src/aot_autograd.py index b1e29b6ac410..d4663c6dc71a 100644 --- a/functorch/_src/aot_autograd.py +++ b/functorch/_src/aot_autograd.py @@ -605,7 +605,7 @@ class PytreeThunk: return x return pytree.tree_unflatten(x, self.spec) -KNOWN_TYPES = [torch.Tensor, int, str, float, bool, torch.SymIntNode, torch.SymFloatNode] +KNOWN_TYPES = [torch.Tensor, int, str, float, bool, torch.SymInt, torch.SymFloat] def aot_function( diff --git a/functorch/_src/partitioners.py b/functorch/_src/partitioners.py index 1077904528ef..c82afe65787b 100644 --- a/functorch/_src/partitioners.py +++ b/functorch/_src/partitioners.py @@ -209,7 +209,7 @@ def _tensor_nbytes(numel, dtype): def _size_of(node: fx.Node) -> int: def to_size_hint(s): - if isinstance(s, torch.SymIntNode): + if isinstance(s, torch.SymInt): py_s = s.get_pyobj() return py_s.shape_env.size_hint(py_s.expr) assert isinstance(s, int) diff --git a/functorch/experimental/cond.py b/functorch/experimental/cond.py index 6f7bcbf506d8..e620dbadeccb 100644 --- a/functorch/experimental/cond.py +++ b/functorch/experimental/cond.py @@ -18,6 +18,8 @@ cond = PyOperator('cond') def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands): def _unwrap_proxy(e): + if not isinstance(e, (torch.Tensor, torch.SymInt, torch.SymFloat)): + return e return get_proxy_slot(e, proxy_mode.tracer, e, lambda e: e.proxy) assert isinstance(operands, list), "Cond operands must be a list of tensors" diff --git a/test/cpp/jit/test_misc.cpp b/test/cpp/jit/test_misc.cpp index 6e3283f62a5b..2aac6cacdffc 100644 --- a/test/cpp/jit/test_misc.cpp +++ b/test/cpp/jit/test_misc.cpp @@ -1447,35 +1447,29 @@ TEST(TestSymInt, AddSymbolicInt) { } #ifndef C10_MOBILE -TEST(TestSymInt, TestIntrusive) { - auto a = c10::make_intrusive(); - auto b = c10::make_intrusive(); - ASSERT_EQ(a.use_count(), 1); - ASSERT_EQ(b.use_count(), 1); - auto as = a->toSymInt(); - auto bs = b->toSymInt(); - ASSERT_EQ(a.use_count(), 2); - ASSERT_EQ(b.use_count(), 2); - as = bs; - ASSERT_EQ(a.use_count(), 1); - ASSERT_EQ(b.use_count(), 3); -} - -class TestSymIntNodeImpl : public c10::SymIntNodeImpl { +class TestSymNodeImpl : public c10::SymNodeImpl { public: - TestSymIntNodeImpl(int64_t i) : i_(i) {} + explicit TestSymNodeImpl(int64_t i) : i_(i) {} + + bool is_int() override { + return true; + }; + + bool is_float() override { + return false; + }; bool bool_() override { return static_cast(i_); }; -#define OPDEF3(NAME, OP, RET) \ - RET NAME(const c10::SymIntNode& other) override { \ - return make_intrusive( \ - this->i_ OP dynamic_cast(other.get())->i_); \ +#define OPDEF3(NAME, OP, RET) \ + RET NAME(const c10::SymNode& other) override { \ + return make_intrusive( \ + this->i_ OP dynamic_cast(other.get())->i_); \ } -#define OPDEF2(NAME, OP) OPDEF3(NAME, OP, c10::SymIntNode) +#define OPDEF2(NAME, OP) OPDEF3(NAME, OP, c10::SymNode) OPDEF2(add, +) OPDEF2(sub, -) OPDEF2(mul, *) @@ -1494,17 +1488,19 @@ class TestSymIntNodeImpl : public c10::SymIntNodeImpl { int64_t i_; }; -TEST(TestSymInt, TestSymIntToSymIntNodeDispatch) { +TEST(TestSymInt, TestSymIntToSymNodeDispatch) { auto get = [](c10::SymInt si) { - auto node = si.toSymIntNodeImpl(); - return dynamic_cast(node.get())->i_; + auto node = si.toSymNodeImpl(); + return dynamic_cast(node.get())->i_; }; std::vector inputs{0, 1, -1, 4, -4, 777, -777}; for (auto i : inputs) { for (auto j : inputs) { - auto a = c10::make_intrusive(i)->toSymInt(); - auto b = c10::make_intrusive(j)->toSymInt(); + auto a = c10::SymInt( + static_cast(c10::make_intrusive(i))); + auto b = c10::SymInt( + static_cast(c10::make_intrusive(j))); ASSERT_EQ(get(a + b), i + j); ASSERT_EQ(get(a - b), i - j); ASSERT_EQ(get(a * b), i * j); diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index b183b6169dd6..0e85b54cfe3f 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -12,8 +12,9 @@ import itertools import io from torch.utils._pytree import tree_map from torch.fx.experimental.proxy_tensor import make_fx -from torch.fx.experimental.symbolic_shapes import ShapeEnv, PySymInt, sym_float +from torch.fx.experimental.symbolic_shapes import ShapeEnv, sym_float, guard_int, SymNode from torch.utils._python_dispatch import TorchDispatchMode +from torch import SymInt aten = torch.ops.aten @@ -116,9 +117,6 @@ def create_symbolic_tensor(name, arg, shape_env, storage_offset=0): sym_shapes, sym_strides = shape_env.create_symbolic_sizes_strides(arg) return FakeSymbolicTensor(sym_shapes, sym_strides, arg.dtype, arg.layout, arg.requires_grad, arg.device, storage_offset) - -CPP_SYMINT_CLASS = type(torch.SymIntNode.new_symint(1)) - def create_symint(shape_env, i): return shape_env.create_symintnode(shape_env.create_symbol(i)) @@ -156,8 +154,8 @@ class TestPySymInt(TestCase): shape_env = ShapeEnv() x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env) - self.assertTrue(not isinstance(x.shape[0], PySymInt)) - self.assertTrue(isinstance(x.shape[0], CPP_SYMINT_CLASS)) + self.assertTrue(not isinstance(x.shape[0], SymNode)) + self.assertTrue(isinstance(x.shape[0], SymInt)) self.assertTrue(x.shape[0] == 5) self.assertTrue(x.shape[1] == 4) @@ -165,17 +163,17 @@ class TestPySymInt(TestCase): self.assertTrue(x.size()[0], 5) self.assertTrue(x.size()[1], 4) - self.assertTrue(isinstance(x.size()[1], CPP_SYMINT_CLASS)) + self.assertTrue(isinstance(x.size()[1], SymInt)) self.assertTrue(x.size()[2] == 3) self.assertTrue(x.size(0) == 5) self.assertTrue(x.size(1) == 4) self.assertTrue(x.size(2) == 3) - self.assertTrue(isinstance(x.size(2), CPP_SYMINT_CLASS)) + self.assertTrue(isinstance(x.size(2), SymInt)) offset = create_symint(shape_env, 2) y = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env, offset) - self.assertTrue(isinstance(y.storage_offset(), CPP_SYMINT_CLASS)) + self.assertTrue(isinstance(y.storage_offset(), SymInt)) self.assertTrue(y.storage_offset() == 2) offset = 2 @@ -267,7 +265,7 @@ class TestPySymInt(TestCase): def test_stride(self): shape_env = ShapeEnv() x = create_symbolic_tensor("x", torch.randn(5, 5), shape_env) - self.assertIsInstance(x.stride()[0], CPP_SYMINT_CLASS) + self.assertIsInstance(x.stride()[0], SymInt) @skipIfNoSympy def test_size_expressions(self): @@ -290,7 +288,7 @@ class TestPySymInt(TestCase): shape_env = ShapeEnv() x = create_symbolic_tensor("x", torch.randn(5), shape_env) r = sym_float(x.shape[0]) - self.assertTrue(isinstance(r, torch.SymFloatNode)) + self.assertIsInstance(r, torch.SymFloat, msg=type(r)) @skipIfNoSympy def test_aten_ops(self): @@ -320,13 +318,13 @@ class TestPySymInt(TestCase): shape_env = ShapeEnv() a0 = create_symint(shape_env, 2) r = torch.empty(a0, device='meta') - self.assertIsInstance(r.shape[0], CPP_SYMINT_CLASS) + self.assertIsInstance(r.shape[0], SymInt) @skipIfNoSympy def test_guard_int(self): shape_env = ShapeEnv() a0 = create_symint(shape_env, 2) - self.assertEqual(a0.guard_int(), 2) + self.assertEqual(guard_int(a0), 2) self.assertEqual(str(shape_env.guards[0][0]), "Eq(s0, 2)") @skipIfNoSympy @@ -347,7 +345,9 @@ class TestPySymInt(TestCase): assert func == torch.ops.aten.add.Tensor nonlocal sym_int_encountered - sym_int_encountered = kwargs["alpha"] is a0 + # WARNING: do not do identity tests on the outer + # SymInt/SymFloat, they are NOT STABLE + sym_int_encountered = kwargs["alpha"].node is a0.node kwargs["alpha"] = 0 return func(*args) diff --git a/test/test_dynamic_shapes.py.bak b/test/test_dynamic_shapes.py.bak deleted file mode 100644 index 19c77fe4d7ab..000000000000 --- a/test/test_dynamic_shapes.py.bak +++ /dev/null @@ -1,391 +0,0 @@ -# -*- coding: utf-8 -*- -# Owner(s): ["oncall: jit"] - -from torch._C import _disabled_torch_function_impl -import torch.fx -import torch.nn.functional as F -from torch.testing._internal.common_utils import run_tests, TestCase, skipIfTorchDynamo -import unittest -import torch -import operator -import itertools -import io -from torch.utils._pytree import tree_map -from torch.fx.experimental.proxy_tensor import make_fx -from torch.fx.experimental.symbolic_shapes import ShapeEnv, PySymInt, sym_float -from torch.utils._python_dispatch import TorchDispatchMode - -aten = torch.ops.aten - -try: - import sympy - HAS_SYMPY = True -except ImportError: - HAS_SYMPY = False -skipIfNoSympy = unittest.skipIf(not HAS_SYMPY, "no sympy") - - -meta_funcs = {} - - -def register_meta(op): - def decorator(f): - def add_func(op): - meta_funcs[op] = f - tree_map(add_func, op) - return f - return decorator - - -@register_meta([aten.add.Tensor, aten.sub.Tensor]) -def binary_meta(a, b): - return a.new_empty(a.shape) - - -@register_meta(aten.cat.default) -def cat_meta(tensors, dim=0): - concat_length = 0 - shape = tensors[0].shape - for tensor in tensors: - for idx, (common_length, length) in enumerate(zip(shape, tensor.shape)): - if idx == dim: - concat_length = concat_length + length - else: - assert length == common_length - new_shape = list(shape) - new_shape[dim] = concat_length - return tensors[0].new_empty(new_shape) - - -@register_meta([aten.narrow_copy.default]) -def narrow_copy_symint_meta(a, dim, start, length, **kwargs): - shape = [] - for i, x in enumerate(a.shape): - if i == dim: - shape.append(length) - else: - shape.append(x) - return a.new_empty(tuple(shape)) - - -@register_meta([aten.expand.default]) -def expand_symint_meta(a, size, implicit=False): - return a.new_empty(size) - - -def create_contiguous(shape): - strides = [1] - for dim in reversed(shape[:-1]): - strides.append(dim * strides[-1]) - return list(reversed(strides)) - - -class FakeSymbolicTensor(torch.Tensor): - @staticmethod - def __new__(cls, sym_shape, sym_strides, dtype, layout, requires_grad, device, storage_offset=0): - # TODO: this is wrong in general - sym_stride = create_contiguous(sym_shape) - r = torch.Tensor._make_wrapper_subclass( - cls, sym_shape, - sym_stride, storage_offset, - dtype=dtype, layout=layout, requires_grad=requires_grad, - device=device, - ) - return r - - __torch_function__ = _disabled_torch_function_impl - - def new_empty(self, shape): - return FakeSymbolicTensor(shape, None, self.dtype, self.layout, self.requires_grad, self.device) - - @classmethod - def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None): - if func_overload in meta_funcs: - return meta_funcs[func_overload](*args, **kwargs) - - if func_overload == torch.ops.aten.new_empty.default: - self = args[0] - shape = args[1] - return FakeSymbolicTensor(shape, self.stride(), self.dtype, self.layout, self.requires_grad, self.device) - - raise RuntimeError(f"operator {func_overload} not supported") - - -def create_symbolic_tensor(name, arg, shape_env, storage_offset=0): - sym_shapes, sym_strides = shape_env.create_symbolic_sizes_strides(arg) - return FakeSymbolicTensor(sym_shapes, sym_strides, arg.dtype, arg.layout, arg.requires_grad, arg.device, storage_offset) - - -CPP_SYMINT_CLASS = type(torch.SymIntNode.new_symint(1)) - -def create_symint(shape_env, i): - return shape_env.create_symintnode(shape_env.create_symbol(i)) - -@skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)") -class TestPySymInt(TestCase): - - @skipIfNoSympy - def test_arith_ops(self): - shape_env = ShapeEnv() - symints = [] - for i in range(2, 5): - symints.append((i, create_symint(shape_env, i))) - - ops = [operator.add, operator.sub, operator.floordiv, operator.mul, operator.mod] - - for op in ops: - for args in itertools.permutations(symints, 2): - if not isinstance(args[0][1], int) and ((op != operator.mod or op != operator.floordiv) and args[1][0] != 0): - self.assertTrue(op(args[0][1], args[1][1]) == op(args[0][0], args[1][0])) - - - @skipIfNoSympy - def test_reverse_arith_ops(self): - shape_env = ShapeEnv() - - a = create_symint(shape_env, 2) - self.assertTrue(5 // a == 5 // 2) - - a = create_symint(shape_env, 2) - self.assertTrue(5 * a == 5 * 2) - - - @skipIfNoSympy - def test_roundtrip(self): - shape_env = ShapeEnv() - x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env) - - self.assertTrue(not isinstance(x.shape[0], PySymInt)) - self.assertTrue(isinstance(x.shape[0], CPP_SYMINT_CLASS)) - - self.assertTrue(x.shape[0] == 5) - self.assertTrue(x.shape[1] == 4) - self.assertTrue(x.shape[2], 3) - - self.assertTrue(x.size()[0], 5) - self.assertTrue(x.size()[1], 4) - self.assertTrue(isinstance(x.size()[1], CPP_SYMINT_CLASS)) - self.assertTrue(x.size()[2] == 3) - - self.assertTrue(x.size(0) == 5) - self.assertTrue(x.size(1) == 4) - self.assertTrue(x.size(2) == 3) - self.assertTrue(isinstance(x.size(2), CPP_SYMINT_CLASS)) - - offset = create_symint(shape_env, 2) - y = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env, offset) - self.assertTrue(isinstance(y.storage_offset(), CPP_SYMINT_CLASS)) - self.assertTrue(y.storage_offset() == 2) - - offset = 2 - z = create_symbolic_tensor("z", torch.randn(5, 4, 3), shape_env, offset) - self.assertTrue(isinstance(z.storage_offset(), int)) - self.assertTrue(z.storage_offset() == 2) - - @skipIfNoSympy - def test_binary(self): - shape_env = ShapeEnv() - x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env) - y = create_symbolic_tensor("y", torch.randn(5, 4, 3), shape_env) - - z = x + y - self.assertTrue(z.shape[0] == 5) - self.assertTrue(z.shape[1] == 4) - self.assertTrue(z.shape[2] == 3) - - # broadcasting - y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env) - z = x + y - self.assertTrue(z.shape[0] == 5) - self.assertTrue(z.shape[1] == 4) - self.assertTrue(z.shape[2] == 3) - - @skipIfNoSympy - def test_symint_args(self): - shape_env = ShapeEnv() - x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env) - y = create_symbolic_tensor("y", torch.randn(5, 4, 1), shape_env) - LAST_DIM = 2 - z = x.narrow_copy(LAST_DIM, 0, y.shape[LAST_DIM]) - self.assertTrue(z.shape[2] == y.shape[2]) - - # arithmetic expr with two symints - z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - y.shape[LAST_DIM]) - self.assertTrue(z.shape[2] == 2) - - # arithmetic expr with a symint and python int - z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - 1) - self.assertTrue(z.shape[2] == 2) - - @skipIfNoSympy - def test_symint_vargs(self): - shape_env = ShapeEnv() - x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env) - y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env) - - # varargs - z = y.expand(x.shape[0], y.shape[1], x.shape[2]) - self.assertTrue(z.shape[0] == 5) - self.assertTrue(z.shape[1] == 4) - self.assertTrue(z.shape[2] == 3) - - # shape list - z = y.expand((x.shape[0], y.shape[1], x.shape[2])) - self.assertTrue(z.shape[0] == 5) - self.assertTrue(z.shape[1] == 4) - self.assertTrue(z.shape[2] == 3) - - # mixed python symints and ints - z = y.expand(x.shape[0], y.shape[1], 3) - self.assertTrue(z.shape[0] == 5) - self.assertTrue(z.shape[1] == 4) - self.assertTrue(z.shape[2] == 3) - - # mixed python symints and ints in a list - z = y.expand((x.shape[0], y.shape[1], 3)) - self.assertTrue(z.shape[0] == 5) - self.assertTrue(z.shape[1] == 4) - self.assertTrue(z.shape[2] == 3) - - # mixed python symints and ints - z = y.expand(5, y.shape[1], x.shape[2]) - self.assertTrue(z.shape[0] == 5) - self.assertTrue(z.shape[1] == 4) - self.assertTrue(z.shape[2] == 3) - - # mixed python ints and symints in a list - z = y.expand((5, y.shape[1], x.shape[2])) - self.assertTrue(z.shape[0] == 5) - self.assertTrue(z.shape[1] == 4) - self.assertTrue(z.shape[2] == 3) - - z = y.expand((y.shape[1],)) - z = y.expand(y.shape[1]) - - @skipIfNoSympy - def test_stride(self): - shape_env = ShapeEnv() - x = create_symbolic_tensor("x", torch.randn(5, 5), shape_env) - self.assertIsInstance(x.stride()[0], CPP_SYMINT_CLASS) - - @skipIfNoSympy - def test_size_expressions(self): - shape_env = ShapeEnv() - x = create_symbolic_tensor("x", torch.randn(5), shape_env) - expand_x = x.expand(x.shape[0], x.shape[0]) - if expand_x.shape[0] > 3: - result = expand_x + expand_x - else: - result = expand_x + expand_x - - gt_op = shape_env.guards[0][0] - self.assertTrue(isinstance(gt_op, sympy.core.relational.StrictGreaterThan)) - self.assertTrue(str(x.shape[0]), str(gt_op.args[0])) - self.assertTrue(str(expand_x.shape[1]), str(x.shape[0])) - self.assertTrue(str(expand_x.shape[1]), str(result.shape[0])) - - @skipIfNoSympy - def test_int_to_float(self): - shape_env = ShapeEnv() - x = create_symbolic_tensor("x", torch.randn(5), shape_env) - r = sym_float(x.shape[0]) - self.assertTrue(isinstance(r, torch.SymFloatNode)) - - @skipIfNoSympy - def test_aten_ops(self): - - shape_env = ShapeEnv() - x = create_symbolic_tensor("x", torch.randn(5), shape_env) - torch.ops.aten.narrow_copy.default(x, 0, 0, x.shape[0]) - - shape_env = ShapeEnv() - x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env) - torch.ops.aten.expand.default(x, [x.shape[0], x.shape[1], x.shape[2]]) - - def test_fx_trace_intlist(self): - class CustomModule(torch.nn.Module): - def forward(self, x): - bs, c, h, w = x.shape - return F.pad(x, (0, w % 2, 0, h % 2, 0, 0)) - - m = CustomModule() - x = torch.rand(1, 3, 4, 4) - # should not TypeError: pad(): argument 'pad' (position 2) must be - # tuple of ints, not tuple - torch.fx.symbolic_trace(m) - - @skipIfNoSympy - def test_meta_symint(self): - shape_env = ShapeEnv() - a0 = create_symint(shape_env, 2) - r = torch.empty(a0, device='meta') - self.assertIsInstance(r.shape[0], CPP_SYMINT_CLASS) - - @skipIfNoSympy - def test_guard_int(self): - shape_env = ShapeEnv() - a0 = create_symint(shape_env, 2) - self.assertEqual(a0.guard_int(), 2) - self.assertEqual(str(shape_env.guards[0][0]), "s0") - self.assertEqual(shape_env.guards[0][1], 2) - - @skipIfNoSympy - def test_int_conversion(self): - shape_env = ShapeEnv() - a0 = create_symint(shape_env, 2) - self.assertRaisesRegex(RuntimeError, "Trying to extract", lambda: int(a0)) - - @skipIfNoSympy - def test_symint_as_scalar(self): - shape_env = ShapeEnv() - a0 = create_symint(shape_env, 2) - - sym_int_encountered = False - - class TestSymInt(TorchDispatchMode): - def __torch_dispatch__(self, func, types, args=(), kwargs=None): - assert func == torch.ops.aten.add.Tensor - - nonlocal sym_int_encountered - sym_int_encountered = kwargs["alpha"] is a0 - kwargs["alpha"] = 0 - return func(*args) - - x = torch.rand([4, 4]) - with TestSymInt(): - y = torch.add(x, x, alpha=a0) - - self.assertTrue(sym_int_encountered) - - @skipIfNoSympy - @unittest.mock.patch('sys.stdout', new_callable=io.StringIO) - def test_print_readable_with_symints(self, mock_stdout): - def f(a, b): - dim0 = a.shape[0] + b.shape[0] - dim1 = a.shape[1] + b.shape[1] - d = a.new_empty(dim0, dim1) - d = torch.ops.aten.native_dropout(d, 0.5, train=True) - return d - - fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5, 3), torch.randn(4, 3)) - fx_g.print_readable() - - self.assertExpectedInline(mock_stdout.getvalue().strip(), """\ -class f(torch.nn.Module): - def forward(self, a_1: f32[t0.size(0),t0.size(1)], b_1: f32[t1.size(0),t0.size(1)]): - # No stacktrace found for following nodes - sym_size: Sym(t0.size(0)) = torch.ops.aten.sym_size(a_1, 0) - sym_size_1: Sym(t1.size(0)) = torch.ops.aten.sym_size(b_1, 0) - add: Sym(t0.size(0) + t1.size(0)) = sym_size + sym_size_1; sym_size = sym_size_1 = None - sym_size_2: Sym(t0.size(1)) = torch.ops.aten.sym_size(a_1, 1) - sym_size_3: Sym(t0.size(1)) = torch.ops.aten.sym_size(b_1, 1); b_1 = None - add_1: Sym(2*t0.size(1)) = sym_size_2 + sym_size_3; sym_size_2 = sym_size_3 = None - new_empty: f32[t0.size(0) + t1.size(0),2*t0.size(1)] = torch.ops.aten.new_empty.default(a_1, [add, add_1], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False); a_1 = add = add_1 = None - native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True); new_empty = None - getitem: f32[t0.size(0) + t1.size(0),2*t0.size(1)] = native_dropout[0] - getitem_1: b8[t0.size(0) + t1.size(0),2*t0.size(1)] = native_dropout[1]; native_dropout = None - return (getitem, getitem_1)""") # noqa: B950 - - -if __name__ == '__main__': - run_tests() diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 1d5985a00da8..6cb7d280cc19 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -875,8 +875,7 @@ def forward(self, a_1): self.assertExpectedInline(r, """\ def forward(self, a_1): sym_size = torch.ops.aten.sym_size(a_1, 0) - sym_float = torch.fx.experimental.symbolic_shapes.sym_float(sym_size); sym_size = None - pow_1 = sym_float ** 0.5; sym_float = None + pow_1 = sym_size ** 0.5; sym_size = None div = torch.ops.aten.div.Tensor(a_1, pow_1); a_1 = pow_1 = None return div""") @@ -949,7 +948,7 @@ def forward(self, a_1): fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5), torch.randn(4)) meta_c = _get_node(fx_g, lambda x: x.target == aten.new_empty.default) meta_d = _get_node(fx_g, lambda x: x.target == operator.add) - self.assertTrue(meta_c.meta['val'].shape[0].get_pyobj().expr == meta_d.meta['val'].expr) + self.assertTrue(meta_c.meta['val'].shape[0].get_pyobj().expr == meta_d.meta['val'].node.expr) def test_metadata_fresh(self): def f(x): diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py index 5215281b7ac6..c4a64b5cb647 100644 --- a/test/test_public_bindings.py +++ b/test/test_public_bindings.py @@ -207,8 +207,8 @@ class TestPublicBindings(TestCase): "StreamObjType", "StringType", "SUM", - "SymFloatNode", - "SymIntNode", + "SymFloat", + "SymInt", "TensorType", "ThroughputBenchmark", "TracingState", diff --git a/tools/autograd/gen_autograd_functions.py b/tools/autograd/gen_autograd_functions.py index 7b120593eb53..3e9e125bfb9f 100644 --- a/tools/autograd/gen_autograd_functions.py +++ b/tools/autograd/gen_autograd_functions.py @@ -291,7 +291,7 @@ PyObject* tup = PyTuple_New((Py_ssize_t) prop.size()); for (auto i : c10::irange(prop.size())) { auto si = prop[i]; if (si.is_symbolic()) { - auto py_symint = py::cast(si.toSymIntNodeImpl()).release().ptr(); + auto py_symint = py::cast(si).release().ptr(); PyTuple_SetItem(tup, (Py_ssize_t) i, py_symint); } else { PyTuple_SetItem(tup, (Py_ssize_t) i, PyLong_FromUnsignedLong(si.as_int_unchecked())); @@ -313,7 +313,7 @@ return PyLong_FromUnsignedLong((int64_t) prop); """ GETTER_BODY_SYMINT = """\ -return prop.is_symbolic() ? py::cast(prop.toSymIntNodeImpl()).release().ptr() : PyLong_FromUnsignedLong(prop.as_int_unchecked()); +return prop.is_symbolic() ? py::cast(prop).release().ptr() : PyLong_FromUnsignedLong(prop.as_int_unchecked()); """ GETTER_BODY_DOUBLE = """\ diff --git a/tools/autograd/templates/python_functions.cpp b/tools/autograd/templates/python_functions.cpp index 57343a53ea98..eacf56b31d88 100644 --- a/tools/autograd/templates/python_functions.cpp +++ b/tools/autograd/templates/python_functions.cpp @@ -5,7 +5,7 @@ #include #include -#include +#include #include "torch/csrc/autograd/generated/Functions.h" #include "torch/csrc/autograd/python_cpp_function.h" #include diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp index e4df2a8dc61d..7122532a5441 100644 --- a/tools/autograd/templates/python_variable_methods.cpp +++ b/tools/autograd/templates/python_variable_methods.cpp @@ -240,12 +240,7 @@ static PyObject * THPVariable_numel(PyObject* self, PyObject* args) if (jit::tracer::isTracing()) { return wrap(jit::tracer::getNumelOf(self_)); } else { - auto si = self_.sym_numel(); - if (si.is_symbolic()) { - return py::cast(si.toSymIntNodeImpl()).release().ptr(); - } else { - return THPUtils_packInt64(si.as_int_unchecked()); - } + return py::cast(self_.sym_numel()).release().ptr(); } END_HANDLE_TH_ERRORS } diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index 417d73f829a6..0d1cdcb4ad06 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -722,7 +722,7 @@ def gen_pyi( binop += "_" out_suffix = "" unsorted_tensor_method_hints[binop].append( - "def {}(self, other: Union[Tensor, Number, torch.SymIntNode, torch.SymFloatNode]{})" + "def {}(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat]{})" " -> Tensor: ...".format(binop, out_suffix) ) for binop in ["add", "sub"]: @@ -732,7 +732,7 @@ def gen_pyi( binop += "_" out_suffix = "" unsorted_tensor_method_hints[binop].append( - "def {}(self, other: Union[Tensor, Number, torch.SymIntNode, torch.SymFloatNode], " + "def {}(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat], " "*, alpha: Optional[Number]=1{})" " -> Tensor: ...".format(binop, out_suffix) ) diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 792e23199916..8b5a5d8e83b3 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -169,20 +169,6 @@ class Future(object): def _jit_set_num_profiled_runs(num: _size) -> _size: ... -class SymIntNode(object): - def get_pyobj(self) -> Any: ... - - @staticmethod - def new_symint(obj) -> SymIntNode: ... - -class SymFloatNode(object): - def get_pyobj(self) -> Any: ... - - @staticmethod - def new_symfloat(obj) -> SymFloatNode: ... - - def __ceil__(self) -> SymIntNode: ... - # Defined in torch/csrc/jit/passes/xnnpack_rewrite.h class MobileOptimizerType: ... diff --git a/torch/__init__.py b/torch/__init__.py index 63995d6ec7f6..c2f2c4c3327f 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -47,7 +47,7 @@ __all__ = [ 'is_deterministic_algorithms_warn_only_enabled', 'set_deterministic_debug_mode', 'get_deterministic_debug_mode', 'set_float32_matmul_precision', 'get_float32_matmul_precision', - 'set_warn_always', 'is_warn_always_enabled', + 'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat', ] ################################################################################ @@ -196,6 +196,67 @@ else: if TYPE_CHECKING: import torch._C as _C +class SymInt: + """ + Like an int (including magic methods), but redirects all operations on the + wrapped node. This is used in particular to symbolically record operations + in the symbolic shape workflow. + """ + + def __init__(self, node): + from torch.fx.experimental.symbolic_shapes import SymNode + assert isinstance(node, SymNode) + # This field MUST be named node; C++ binding code assumes that this + # class has a field named node that stores SymNode + self.node = node + + # Magic methods installed later + + def __bool__(self): + return self.node.bool_() + + def __int__(self): + return self.node.int_() + + def __sym_float__(self): + return SymFloat(self.node.sym_float()) + + def __repr__(self): + return self.node.str() + + # For BC; direct access of node is OK too + def get_pyobj(self): + return self.node + +class SymFloat: + """ + Like an float (including magic methods), but redirects all operations on the + wrapped node. This is used in particular to symbolically record operations + in the symbolic shape workflow. + """ + + def __init__(self, node): + from torch.fx.experimental.symbolic_shapes import SymNode + assert isinstance(node, SymNode) + # This field MUST be named node; C++ binding code assumes that this + # class has a field named node that stores SymNode + self.node = node + + # Magic methods installed later + + def __bool__(self): + return self.node.bool_() + + def __sym_int__(self): + return SymInt(self.node.sym_int()) + + def __repr__(self): + return self.node.str() + + # For BC; direct access of node is OK too + def get_pyobj(self): + return self.node + # Check to see if we can load C extensions, and if not provide some guidance # on what the problem might be. try: @@ -941,7 +1002,6 @@ from ._linalg_utils import ( # type: ignore[misc] lstsq, ) - def _register_device_module(device_type, module): r"""Register an external runtime module of the specific :attr:`device_type` supported by torch. @@ -971,3 +1031,6 @@ if 'TORCH_CUDA_SANITIZER' in os.environ: import torch.cuda._sanitizer as csan csan.enable_cuda_sanitizer() + +# Populate magic methods on SymInt and SymFloat +import torch.fx.experimental.symbolic_shapes diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 864d2c4ca3e0..ab4cbf62ce36 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -337,7 +337,7 @@ class TensorVariable(VariableTracker): from . import UserDefinedObjectVariable return UserDefinedObjectVariable(example_value) - elif isinstance(example_value, torch.SymIntNode): + elif isinstance(example_value, torch.SymInt): proxy.node.meta["example_value"] = example_value return cls(proxy, **options) else: diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 3e274be50615..71934419de2f 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -40,11 +40,9 @@ class GraphLowering(torch.fx.Interpreter): else: size, stride = self._shape_env.create_symbolic_sizes_strides(ex) - size = [ - i.get_pyobj().expr if isinstance(i, torch.SymIntNode) else i for i in size - ] + size = [i.get_pyobj().expr if isinstance(i, torch.SymInt) else i for i in size] stride = [ - i.get_pyobj().expr if isinstance(i, torch.SymIntNode) else i for i in stride + i.get_pyobj().expr if isinstance(i, torch.SymInt) else i for i in stride ] return size, stride diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py index b54019ef031c..bf71b0069585 100644 --- a/torch/_prims/__init__.py +++ b/torch/_prims/__init__.py @@ -392,8 +392,8 @@ def _elementwise_meta( # Number case # NOTE: this case is not currently exercised # TODO: fix number type promotion (bool, complex->float) - assert not isinstance(number, torch.SymIntNode), "NYI" - assert not isinstance(number, torch.SymFloatNode), "NYI" + assert not isinstance(number, torch.SymInt), "NYI" + assert not isinstance(number, torch.SymFloat), "NYI" return TensorMeta(number) @@ -932,7 +932,7 @@ bitwise_xor = _make_elementwise_binary_prim( # div prim performs truncation division on integer inputs # and true division for floating and complex inputs def _div_aten(a, b): - is_integral = isinstance(a, (bool, int, torch.SymIntNode)) or ( + is_integral = isinstance(a, (bool, int, torch.SymInt)) or ( isinstance(a, torch.Tensor) and utils.is_integer_dtype(a.dtype) ) diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py index d8321ac9a47c..ee4dd38a655c 100644 --- a/torch/_prims_common/__init__.py +++ b/torch/_prims_common/__init__.py @@ -42,18 +42,18 @@ ShapeType = Union[torch.Size, List[int], Tuple[int, ...]] StrideType = Union[List[int], Tuple[int, ...]] DimsType = Union[int, List[int], Tuple[int, ...]] DimsSequenceType = Union[List[int], Tuple[int, ...]] -# TODO: Type[torch.SymIntNode], Type[torch.SymFloatNode] +# TODO: Type[torch.SymInt], Type[torch.SymFloat] NumberTypeType = Union[Type[bool], Type[int], Type[float], Type[complex]] # TODO: This needs a lot more type annotations -# NumberType = Union[bool, int, float, complex, torch.SymIntNode, torch.SymFloatNode] +# NumberType = Union[bool, int, float, complex, torch.SymInt, torch.SymFloat] NumberType = Union[bool, int, float, complex] -Number = (bool, int, float, complex, torch.SymIntNode, torch.SymFloatNode) +Number = (bool, int, float, complex, torch.SymInt, torch.SymFloat) # I don't call it Integral because numbers.Integral includes bool, but IntLike # does not Dim = int -IntLike = (int, torch.SymIntNode) -FloatLike = (float, torch.SymFloatNode) +IntLike = (int, torch.SymInt) +FloatLike = (float, torch.SymFloat) IntWithoutSymInt = int FloatWithoutSymFloat = float DeviceLikeType = Union[str, torch.device] @@ -1113,10 +1113,10 @@ class RETURN_TYPE(Enum): # TODO: when NumberType contains the sym types, can simplify this -def number_type(x: Union[NumberType, torch.SymIntNode, torch.SymFloatNode]) -> Type: - if isinstance(x, torch.SymIntNode): +def number_type(x: Union[NumberType, torch.SymInt, torch.SymFloat]) -> Type: + if isinstance(x, torch.SymInt): return int - elif isinstance(x, torch.SymFloatNode): + elif isinstance(x, torch.SymFloat): return float else: return type(x) diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 652c24c9a521..c5bf346f8cb5 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -656,7 +656,7 @@ class FakeTensorMode(TorchDispatchMode): return args[0].fake_device flat_arg_fake_tensors = tree_flatten_only(FakeTensor, (args, kwargs)) - flat_symints = tree_flatten_only(torch.SymIntNode, (args, kwargs)) + flat_symints = tree_flatten_only(torch.SymInt, (args, kwargs)) has_symbolic_sizes = ( any([i._has_symbolic_sizes_strides for i in flat_arg_fake_tensors]) or len(flat_symints) > 0 diff --git a/torch/csrc/Size.cpp b/torch/csrc/Size.cpp index 36419f20eccd..ba4090bfb684 100644 --- a/torch/csrc/Size.cpp +++ b/torch/csrc/Size.cpp @@ -59,7 +59,7 @@ PyObject* THPSize_NewFromSymSizes(const at::Tensor& self_) { TORCH_CHECK( !torch::jit::tracer::isTracing(), "JIT Tracing of SymInts isn't supported"); - auto py_symint = py::cast(si.toSymIntNodeImpl()).release().ptr(); + auto py_symint = py::cast(si).release().ptr(); if (!py_symint) throw python_error(); PyTuple_SET_ITEM(ret.get(), i, py_symint); @@ -98,7 +98,7 @@ static PyObject* THPSize_pynew( if (THPUtils_checkLong(item)) { continue; } - if (torch::is_symint_node(item)) { + if (torch::is_symint(item)) { continue; } if (torch::jit::tracer::isTracing() && isTracedZeroDimVar(item)) { @@ -135,7 +135,7 @@ static PyObject* THPSize_repr(THPSize* self) { auto item = PyTuple_GET_ITEM(self, i); auto ih = py::handle(item); - repr += torch::is_symint_node(ih) + repr += torch::is_symint(ih) ? std::string(py::str(ih)) : std::to_string(THPUtils_unpackLong(PyTuple_GET_ITEM(self, i))); } diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 66b8ad2d8351..7e07f3ff32cd 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -2646,9 +2646,8 @@ c10::SymInt ConcretePyInterpreterVTable::sym_numel( "Cannot call numel on a tensor with symbolic shapes/strides"); return self->sym_numel_default(); } - return torch::is_symint_node(out) - ? out.cast()->toSymInt() - : c10::SymInt{py::cast(out)}; + return torch::is_symint(out) ? out.cast() + : c10::SymInt{py::cast(out)}; } c10::SymInt ConcretePyInterpreterVTable::sym_storage_offset( @@ -2669,9 +2668,8 @@ c10::SymInt ConcretePyInterpreterVTable::sym_storage_offset( if (out.is(py::none())) { return self->sym_storage_offset_default(); } - return torch::is_symint_node(out) - ? out.cast()->toSymInt() - : c10::SymInt{py::cast(out)}; + return torch::is_symint(out) ? out.cast() + : c10::SymInt{py::cast(out)}; } c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_strides( @@ -2701,9 +2699,8 @@ c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_strides( py::list symints; for (auto it = out.begin(); it != out.end(); it++) { auto elm = *it; - auto si = torch::is_symint_node(elm) - ? elm.cast()->toSymInt() - : c10::SymInt{py::cast(elm)}; + auto si = torch::is_symint(elm) ? elm.cast() + : c10::SymInt{py::cast(elm)}; symints.append(si.as_int_unchecked()); } diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 0bb959a3c61e..91eecfa4596e 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -13,7 +13,7 @@ #if (!defined(FBCODE_CAFFE2) && defined(BUILD_ONEDNN_GRAPH)) #include #endif -#include +#include #include #include #include @@ -99,7 +99,6 @@ #include #include -#include #include #include #include @@ -126,249 +125,11 @@ using c10::Argument; using c10::FunctionSchema; using c10::SchemaArgType; using c10::SchemaArgument; -using c10::SymFloat; -using c10::SymFloatNode; -using c10::SymIntNode; +using c10::SymNode; using caffe2::serialize::PyTorchStreamReader; using caffe2::serialize::PyTorchStreamWriter; using torch::utils::SchemaInfo; -static c10::SymIntNode toSymIntNode(c10::SymIntNode a, py::object b) { - return torch::is_symint_node(b) ? b.cast() - : a->wrap(b.cast()); -} - -static c10::SymFloatNode toSymFloatNode(c10::SymFloatNode a, py::object b) { - if (torch::is_symfloat_node(b)) { - return b.cast(); - } else if (torch::is_symint_node(b)) { - return b.cast()->sym_float(); - } else { - return a->wrap(b.cast()); - } -} - -class PythonSymIntNodeImpl : public c10::SymIntNodeImpl { - public: - PythonSymIntNodeImpl(py::object pyobj) : c10::SymIntNodeImpl() { - pyobj_ = std::make_shared( - pyobj.release().ptr(), getPyInterpreter()); - }; - - virtual SymIntNode clone() override { - py::gil_scoped_acquire acquire; - auto r = getPyObj().attr("clone")(); - return c10::make_intrusive(r); - } - - virtual SymIntNode wrap(int64_t num) override { - py::gil_scoped_acquire acquire; - auto r = getPyObj().attr("wrap")(num); - return c10::make_intrusive(r); - } - - virtual bool bool_() override { - py::gil_scoped_acquire acquire; - return getPyObj().attr("__bool__")().is(py::handle(Py_True)); - } - - virtual int64_t guard_int(const char* file, int64_t line) override { - py::gil_scoped_acquire acquire; - return getPyObj().attr("guard_int")(file, line).cast(); - } - - virtual int64_t int_() override { - py::gil_scoped_acquire acquire; - return getPyObj().attr("__int__")().cast(); - } - - SymFloatNode sym_float() override; - - virtual std::string str() override { - py::gil_scoped_acquire acquire; - return getPyObj().attr("__str__")().cast(); - } - - virtual SymIntNode dispatch_common_( - const char* fname, - const SymIntNode& other) { - auto pother = dynamic_cast(other.get()); - TORCH_CHECK(pother); - py::gil_scoped_acquire acquire; - auto r = getPyObj().attr(fname)(pother->getPyObj()); - return c10::make_intrusive(r); - } - - virtual SymIntNode dispatch_common_(const char* fname) { - py::gil_scoped_acquire acquire; - auto r = getPyObj().attr(fname)(); - return c10::make_intrusive(r); - } - - virtual SymIntNode add(const SymIntNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - virtual SymIntNode sub(const SymIntNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - virtual SymIntNode mul(const SymIntNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - virtual SymFloatNode truediv(const SymIntNode& other) override; - - virtual SymIntNode floordiv(const SymIntNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - virtual SymIntNode mod(const SymIntNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - virtual SymIntNode eq(const SymIntNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - virtual SymIntNode gt(const SymIntNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - virtual SymIntNode lt(const SymIntNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - virtual SymIntNode le(const SymIntNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - virtual SymIntNode ge(const SymIntNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - virtual SymIntNode min(const SymIntNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - virtual SymIntNode max(const SymIntNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - virtual SymIntNode ceil() override { - return dispatch_common_(__FUNCTION__); - } - - virtual SymIntNode neg() override { - return dispatch_common_(__FUNCTION__); - } - - py::handle getPyObj() { - return py::handle(pyobj_.get()->ptr(getPyInterpreter())); - } - std::shared_ptr pyobj_ = nullptr; -}; - -class PythonSymFloatNodeImpl : public c10::SymFloatNodeImpl { - public: - PythonSymFloatNodeImpl(py::object pyobj) : c10::SymFloatNodeImpl() { - pyobj_ = std::make_shared( - pyobj.release().ptr(), getPyInterpreter()); - }; - - virtual SymFloatNode wrap(double num) override { - py::gil_scoped_acquire acquire; - auto r = getPyObj().attr("wrap")(num); - return c10::make_intrusive(r); - } - - virtual std::string str() override { - py::gil_scoped_acquire acquire; - return getPyObj().attr("__str__")().cast(); - } - - SymFloatNode dispatch_common_(const char* fname, const SymFloatNode& other) { - auto pother = dynamic_cast(other.get()); - TORCH_CHECK(pother); - py::gil_scoped_acquire acquire; - auto r = getPyObj().attr(fname)(pother->getPyObj()); - return c10::make_intrusive(r); - } - - SymFloatNode add(const SymFloatNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - SymFloatNode sub(const SymFloatNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - SymFloatNode mul(const SymFloatNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - SymFloatNode truediv(const SymFloatNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - SymFloatNode pow(const SymFloatNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - SymFloatNode eq(const SymFloatNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - SymFloatNode gt(const SymFloatNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - SymFloatNode lt(const SymFloatNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - SymFloatNode le(const SymFloatNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - SymFloatNode ge(const SymFloatNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - SymIntNode ceil() override; - SymIntNode floor() override; - - py::handle getPyObj() { - return py::handle(pyobj_.get()->ptr(getPyInterpreter())); - } - std::shared_ptr pyobj_ = nullptr; -}; - -SymFloatNode PythonSymIntNodeImpl::truediv(const SymIntNode& other) { - auto pother = dynamic_cast(other.get()); - TORCH_CHECK(pother); - py::gil_scoped_acquire acquire; - auto r = getPyObj().attr("truediv")(pother->getPyObj()); - return c10::make_intrusive(r); -} - -SymFloatNode PythonSymIntNodeImpl::sym_float() { - py::gil_scoped_acquire acquire; - return c10::make_intrusive( - getPyObj().attr("__sym_float__")()); -} - -SymIntNode PythonSymFloatNodeImpl::ceil() { - py::gil_scoped_acquire acquire; - auto r = getPyObj().attr("ceil")(); - return c10::make_intrusive(r); -} - -SymIntNode PythonSymFloatNodeImpl::floor() { - py::gil_scoped_acquire acquire; - auto r = getPyObj().attr("floor")(); - return c10::make_intrusive(r); -} - namespace { using autograd::variable_list; @@ -1381,276 +1142,41 @@ void initJITBindings(PyObject* module) { } }); - auto symint_class = - py::class_(m, "SymIntNode") - .def_static( - "new_symint", - [](py::object obj) -> c10::SymIntNode { - return c10::make_intrusive(obj); - }) - .def( - "get_pyobj", - [](c10::SymIntNode a) -> py::object { - if (auto* psn = dynamic_cast(a.get())) { - return py::reinterpret_borrow(psn->getPyObj()); - } - return py::none(); - }) - .def( - "__add__", - [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { - auto snb = toSymIntNode(a, b); - return a->add(snb); - }) - .def( - "__radd__", - [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { - auto snb = toSymIntNode(a, b); - return snb->add(a); - }) - .def( - "__sub__", - [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { - auto snb = toSymIntNode(a, b); - return a->sub(snb); - }) - .def( - "__rsub__", - [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { - auto snb = toSymIntNode(a, b); - return snb->sub(a); - }) - .def( - "__mul__", - [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { - auto snb = toSymIntNode(a, b); - return a->mul(snb); - }) - .def( - "__rmul__", - [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { - auto snb = toSymIntNode(a, b); - return snb->mul(a); - }) - .def( - "__truediv__", - [](c10::SymIntNode a, py::object b) -> c10::SymFloatNode { - auto snb = toSymIntNode(a, b); - return a->truediv(snb); - }) - .def( - "__rtruediv__", - [](c10::SymIntNode a, py::object b) -> c10::SymFloatNode { - auto snb = toSymIntNode(a, b); - return snb->truediv(a); - }) - .def( - "__floordiv__", - [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { - auto snb = toSymIntNode(a, b); - return a->floordiv(snb); - }) - .def( - "__rfloordiv__", - [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { - auto snb = toSymIntNode(a, b); - return snb->floordiv(a); - }) - .def( - "__mod__", - [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { - auto snb = toSymIntNode(a, b); - return a->mod(snb); - }) - .def( - "__rmod__", - [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { - auto snb = toSymIntNode(a, b); - return snb->mod(a); - }) - .def( - "__pow__", - [](c10::SymIntNode a, py::object b) -> py::object { - if (PyFloat_Check(b.ptr())) { - auto float_a = a->sym_float(); - return py::cast( - float_a->pow(float_a->wrap(py::cast(b)))); - } - // TODO: integer pow - return py::reinterpret_borrow(Py_NotImplemented); - }) - // TODO: rpow - .def( - "__eq__", - [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { - auto snb = toSymIntNode(a, b); - return a->eq(snb); - }) - .def( - "__gt__", - [](c10::SymIntNode a, py::object b) { - auto snb = toSymIntNode(a, b); - return a->gt(snb); - }) - .def( - "__lt__", - [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { - auto snb = toSymIntNode(a, b); - return a->lt(snb); - }) - .def( - "__le__", - [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { - auto snb = toSymIntNode(a, b); - return a->le(snb); - }) - .def( - "__ge__", - [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { - auto snb = toSymIntNode(a, b); - return a->ge(snb); - }) - .def( - "__ceil__", - [](c10::SymIntNode a) -> c10::SymIntNode { return a->ceil(); }) - .def( - "__neg__", - [](c10::SymIntNode a) -> c10::SymIntNode { return a->neg(); }) - .def( - "__min__", - [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { - auto snb = toSymIntNode(a, b); - return a->min(snb); - }) - .def( - "__max__", - [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { - auto snb = toSymIntNode(a, b); - return a->max(snb); - }) - .def("__bool__", [](c10::SymIntNode a) { return a->bool_(); }) - .def("__int__", [](c10::SymIntNode a) { return a->int_(); }) - // Intentionally don't set file line, as the Python backtrace matters - // more here - .def( - "guard_int", - [](c10::SymIntNode a) { return a->guard_int(nullptr, 0); }) - .def( - "__sym_float__", - [](c10::SymIntNode a) { - // TODO: remove dynamic cast when sym_float is in base class - auto* psn = dynamic_cast(a.get()); - TORCH_INTERNAL_ASSERT(psn); - return psn->sym_float(); - }) - .def("__str__", [](c10::SymIntNode a) { return a->str(); }) - .def("__repr__", [](c10::SymIntNode a) { return a->str(); }); - - py::class_(m, "SymFloatNode") - .def_static( - "new_symfloat", - [](py::object obj) -> c10::SymFloatNode { - return c10::make_intrusive(obj); - }) - .def( - "__add__", - [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode { - auto snb = toSymFloatNode(a, b); - return a->add(snb); - }) - .def( - "__radd__", - [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode { - auto snb = toSymFloatNode(a, b); - return snb->add(a); - }) - .def( - "__sub__", - [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode { - auto snb = toSymFloatNode(a, b); - return a->sub(snb); - }) - .def( - "__mul__", - [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode { - auto snb = toSymFloatNode(a, b); - return a->mul(snb); - }) - .def( - "__rmul__", - [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode { - auto snb = toSymFloatNode(a, b); - return snb->mul(a); - }) - .def( - "__truediv__", - [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode { - auto snb = toSymFloatNode(a, b); - return a->truediv(snb); - }) - .def( - "__rtruediv__", - [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode { - auto snb = toSymFloatNode(a, b); - return snb->truediv(a); - }) - .def( - "__eq__", - [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode { - auto snb = toSymFloatNode(a, b); - return a->eq(snb); - }) - .def( - "__gt__", - [](c10::SymFloatNode a, py::object b) { - auto snb = toSymFloatNode(a, b); - return a->gt(snb); - }) - .def( - "__lt__", - [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode { - auto snb = toSymFloatNode(a, b); - return a->lt(snb); - }) - .def( - "__le__", - [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode { - auto snb = toSymFloatNode(a, b); - return a->le(snb); - }) - .def( - "__ge__", - [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode { - auto snb = toSymFloatNode(a, b); - return a->ge(snb); - }) - .def( - "__pow__", - [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode { - auto snb = toSymFloatNode(a, b); - return a->pow(snb); - }) - .def( - "__rpow__", - [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode { - auto snb = toSymFloatNode(a, b); - return snb->pow(a); - }) - .def( - "__ceil__", - [](c10::SymFloatNode a) -> c10::SymIntNode { return a->ceil(); }) - .def( - "__floor__", - [](c10::SymFloatNode a) -> c10::SymIntNode { return a->floor(); }) - .def( - "get_pyobj", - [](c10::SymFloatNode a) -> py::object { - if (auto* psn = dynamic_cast(a.get())) { - return py::reinterpret_borrow(psn->getPyObj()); - } - return py::none(); - }) - .def("__str__", [](c10::SymFloatNode a) { return a->str(); }); + // NB: This isn't actually used for regular PyTorch symbolic tracing; + // XLA is what needs this +#define SYMNODE_UNARY(n) .def(#n, [](c10::SymNode a) { return a->n(); }) +#define SYMNODE_UNARY2(n2, n) .def(#n2, [](c10::SymNode a) { return a->n(); }) +#define SYMNODE_BINARY(n) \ + .def(#n, [](c10::SymNode a, c10::SymNode b) { return a->n(b); }) + auto symnode_class = + py::class_(m, "_SymNode") + // These DO NOT install magic methods; the SymInt/SymFloat wrapper in + // Python is responsible for this + SYMNODE_UNARY(clone) + // Named these for consistency with inner python class, but maybe + // should change the python side + SYMNODE_UNARY2(__bool__, bool_) SYMNODE_UNARY2(__int__, int_) + SYMNODE_UNARY2(__sym_int__, sym_int) SYMNODE_UNARY2( + __sym_float__, sym_float) SYMNODE_BINARY(add) SYMNODE_BINARY(sub) + SYMNODE_BINARY(mul) SYMNODE_BINARY(truediv) SYMNODE_BINARY(pow) + SYMNODE_BINARY(floordiv) SYMNODE_BINARY(mod) SYMNODE_BINARY( + eq) SYMNODE_BINARY(gt) SYMNODE_BINARY(lt) + SYMNODE_BINARY(le) SYMNODE_BINARY(ge) SYMNODE_BINARY(min) + SYMNODE_BINARY(max) SYMNODE_UNARY(ceil) + SYMNODE_UNARY(floor) SYMNODE_UNARY(neg) + // Intentionally don't set file line, as the + // Python backtrace matters more here + .def( + "guard_int", + [](c10::SymNode a) { + return a->guard_int(nullptr, 0); + }) + .def( + "__str__", + [](c10::SymNode a) { return a->str(); }) + .def("__repr__", [](c10::SymNode a) { + return a->str(); + }); // NOLINTNEXTLINE(bugprone-unused-raii) py::class_(m, "CompleteArgumentSpec") diff --git a/torch/csrc/jit/python/pybind_utils.cpp b/torch/csrc/jit/python/pybind_utils.cpp index 68317f76524b..47089fcc8969 100644 --- a/torch/csrc/jit/python/pybind_utils.cpp +++ b/torch/csrc/jit/python/pybind_utils.cpp @@ -80,10 +80,10 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional N) { scalar = at::Scalar(THPUtils_unpackComplexDouble(obj.ptr())); } else if (THPUtils_checkDouble(obj.ptr())) { scalar = at::Scalar(THPUtils_unpackDouble(obj.ptr())); - } else if (torch::is_symint_node(py::handle(obj))) { + } else if (torch::is_symint(py::handle(obj))) { save_symint = true; scalar = at::Scalar(7777777); - } else if (torch::is_symfloat_node(py::handle(obj))) { + } else if (torch::is_symfloat(py::handle(obj))) { save_symint = true; scalar = at::Scalar(std::numeric_limits::quiet_NaN()); } else { @@ -161,12 +161,12 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional N) { return py::cast(obj); } case TypeKind::SymIntType: - if (torch::is_symint_node(obj.ptr())) { + if (torch::is_symint(obj.ptr())) { return py::cast(obj); } return py::cast(obj); case TypeKind::SymFloatType: - if (torch::is_symfloat_node(obj.ptr())) { + if (torch::is_symfloat(obj.ptr())) { return py::cast(obj); } return py::cast(obj); @@ -253,7 +253,7 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional N) { bool is_symbolic = false; for (auto it = obj.begin(); it != obj.end(); it++) { auto elm = *it; - if (torch::is_symint_node(elm)) { + if (torch::is_symint(elm)) { is_symbolic = true; break; } @@ -269,7 +269,7 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional N) { for (auto it = obj.begin(); it != obj.end(); it++) { auto elm = *it; // TODO: what about SymInt conversion to SymFloat? - if (torch::is_symfloat_node(elm)) { + if (torch::is_symfloat(elm)) { is_symbolic = true; break; } @@ -442,9 +442,9 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional N) { } else if (PyComplex_CheckExact(obj.ptr())) { auto c_obj = py::cast>(obj.ptr()); return static_cast>(c_obj); - } else if (torch::is_symint_node(obj)) { + } else if (torch::is_symint(obj)) { return py::cast(obj); - } else if (torch::is_symfloat_node(obj)) { + } else if (torch::is_symfloat(obj)) { return py::cast(obj); } else { throw py::cast_error( diff --git a/torch/csrc/lazy/core/ir_builder.h b/torch/csrc/lazy/core/ir_builder.h index 20e4730d5013..95605eab1e99 100644 --- a/torch/csrc/lazy/core/ir_builder.h +++ b/torch/csrc/lazy/core/ir_builder.h @@ -136,10 +136,10 @@ static inline NodePtr MakeSizeDiv(const Value& a, const Value& b) { inline Value GetSymIntValue(c10::SymInt a) { return Value( - a.is_symbolic() ? dynamic_cast( - a.toSymIntNodeImpl().get()) - ->node_ - : MakeScalar(a.as_int_unchecked(), at::kLong), + a.is_symbolic() + ? dynamic_cast(a.toSymNodeImpl().get()) + ->node_ + : MakeScalar(a.as_int_unchecked(), at::kLong), 0); } diff --git a/torch/csrc/lazy/core/shape_inference.cpp b/torch/csrc/lazy/core/shape_inference.cpp index bcc73a3ed79f..df82fd45fe29 100644 --- a/torch/csrc/lazy/core/shape_inference.cpp +++ b/torch/csrc/lazy/core/shape_inference.cpp @@ -451,11 +451,11 @@ std::vector compute_shape_expand( std::vector target_size(_sizes.size()); for (const auto idx : c10::irange(_sizes.size())) { if (_sizes[idx].is_symbolic()) { - c10::SymIntNode symbolicIntNode = _sizes[idx].toSymIntNodeImpl(); - auto* lazySymIntNode = - dynamic_cast(symbolicIntNode.get()); - TORCH_INTERNAL_ASSERT(lazySymIntNode); - auto size_node = lazySymIntNode->node_; + c10::SymNode symbolicIntNode = _sizes[idx].toSymNodeImpl(); + auto* lazySymNode = + dynamic_cast(symbolicIntNode.get()); + TORCH_INTERNAL_ASSERT(lazySymNode); + auto size_node = lazySymNode->node_; auto static_value = std::dynamic_pointer_cast(size_node) ->getStaticValue(); diff --git a/torch/csrc/lazy/core/shape_inference.h b/torch/csrc/lazy/core/shape_inference.h index a1b51495fb3f..9ceb45d6b23d 100644 --- a/torch/csrc/lazy/core/shape_inference.h +++ b/torch/csrc/lazy/core/shape_inference.h @@ -4,7 +4,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/torch/csrc/lazy/core/tensor.h b/torch/csrc/lazy/core/tensor.h index 85ea6ab4f4c6..8dfa5a077c97 100644 --- a/torch/csrc/lazy/core/tensor.h +++ b/torch/csrc/lazy/core/tensor.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include #include @@ -10,12 +10,9 @@ namespace torch { namespace lazy { -class TORCH_API SymIntNodeImpl : public c10::SymIntNodeImpl { +class TORCH_API SymNodeImpl : public c10::SymNodeImpl { public: - SymIntNodeImpl(NodePtr ptr) : node_(std::move(ptr)){}; - c10::SymIntNode add(const c10::SymIntNode& other) override { - TORCH_CHECK(false, "NYI"); - } + SymNodeImpl(NodePtr ptr) : node_(std::move(ptr)){}; NodePtr node_; }; diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index f338d3f196ad..f03763f9dca3 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -685,7 +685,7 @@ static bool is_int_list( // NB: do NOT check that the later arguments are ints, as this is // BC-breaking for FX for (int i = 1; i < len; i++) { - if (torch::is_symint_node( + if (torch::is_symint( py::reinterpret_steal(PySequence_GetItem(obj, i)))) { if (failed_idx != nullptr) { *failed_idx = i; @@ -716,9 +716,9 @@ static bool is_int_list( static bool is_int_or_symint(PyObject* obj) { // THPUtils_checkIndex may call __index__ or __int__ // which may have side effects if obj is a symint node - // so we do `is_symint_node` check first + // so we do `is_symint` check first // TODO: maybe we should be using checkLong here? - return torch::is_symint_node(py::handle(obj)) || THPUtils_checkIndex(obj); + return torch::is_symint(py::handle(obj)) || THPUtils_checkIndex(obj); } static bool is_int_or_symint_list( @@ -1570,13 +1570,13 @@ at::Tensor PythonArgs::tensor_slow(int i) { // NB: we DO NOT put symbolic ints/floats into the Scalar itself, // because although Scalar supports SymInt/SymFloat, the subsequent // conversion to Tensor does not. Instead, do it out of band. - } else if (torch::is_symint_node(py::handle(obj))) { + } else if (torch::is_symint(py::handle(obj))) { save_symint = true; // This scalar value doesn't matter, it shouldn't ever actually // get read out. Make it a big and weird looking number to help // people figure out if there's aproblem. scalar = at::Scalar(7777777); - } else if (torch::is_symfloat_node(py::handle(obj))) { + } else if (torch::is_symfloat(py::handle(obj))) { save_symint = true; scalar = at::Scalar(std::numeric_limits::quiet_NaN()); } else { @@ -1633,11 +1633,11 @@ at::Scalar PythonArgs::scalar_slow(PyObject* arg) { return at::Scalar(THPUtils_unpackComplexDouble(arg)); } - if (torch::is_symint_node(arg)) { + if (torch::is_symint(arg)) { return at::Scalar(py::cast(arg)); } - if (torch::is_symfloat_node(arg)) { + if (torch::is_symfloat(arg)) { return at::Scalar(py::cast(arg)); } diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h index acb830addf8f..df084821ba25 100644 --- a/torch/csrc/utils/python_arg_parser.h +++ b/torch/csrc/utils/python_arg_parser.h @@ -61,6 +61,7 @@ #include #include #include +#include #include #include @@ -69,7 +70,7 @@ #include #include -#include +#include #include #include @@ -78,30 +79,6 @@ #include #include -namespace torch { - -inline bool is_symint_node(py::handle obj) { - auto static tp_symn = py::type::of(); - if (py::isinstance(obj, tp_symn)) { - TORCH_CHECK( - !jit::tracer::isTracing(), "JIT tracing of SymInts isn't supported!"); - return true; - } - return false; -} - -inline bool is_symfloat_node(py::handle obj) { - auto static tp_symn = py::type::of(); - if (py::isinstance(obj, tp_symn)) { - TORCH_CHECK( - !jit::tracer::isTracing(), "JIT tracing of SymFloats isn't supported!"); - return true; - } - return false; -} - -} // namespace torch - namespace pybind11 { namespace detail { template <> @@ -109,8 +86,10 @@ struct type_caster { public: PYBIND11_TYPE_CASTER(c10::SymInt, _("SymInt")); bool load(py::handle src, bool) { - if (torch::is_symint_node(src)) { - value = src.cast()->toSymInt(); + if (torch::is_symint(src)) { + value = c10::SymInt(static_cast( + c10::make_intrusive( + src.attr("node")))); return true; } @@ -126,8 +105,15 @@ struct type_caster { c10::SymInt si, return_value_policy /* policy */, handle /* parent */) { - return si.is_symbolic() ? py::cast(si.toSymIntNodeImpl()).release() - : py::cast(si.expect_int()).release(); + if (si.is_symbolic()) { + // TODO: generalize this to work with C++ backed class + auto* py_node = dynamic_cast( + si.toSymNodeImpl().get()); + TORCH_INTERNAL_ASSERT(py_node); + return torch::get_symint_class()(py_node->getPyObj()).release(); + } else { + return py::cast(si.as_int_unchecked()).release(); + } } }; @@ -136,8 +122,10 @@ struct type_caster { public: PYBIND11_TYPE_CASTER(c10::SymFloat, _("SymFloat")); bool load(py::handle src, bool) { - if (torch::is_symfloat_node(src)) { - value = src.cast()->toSymFloat(); + if (torch::is_symfloat(src)) { + value = c10::SymFloat(static_cast( + c10::make_intrusive( + src.attr("node")))); return true; } @@ -153,8 +141,15 @@ struct type_caster { c10::SymFloat si, return_value_policy /* policy */, handle /* parent */) { - return si.is_symbolic() ? py::cast(si.toSymFloatNodeImpl()).release() - : py::cast(si.expect_float()).release(); + if (si.is_symbolic()) { + // TODO: generalize this to work with C++ backed class + auto* py_node = dynamic_cast( + si.toSymNodeImpl().get()); + TORCH_INTERNAL_ASSERT(py_node); + return torch::get_symfloat_class()(py_node->getPyObj()).release(); + } else { + return py::cast(si.as_float_unchecked()).release(); + } } }; } // namespace detail @@ -167,8 +162,7 @@ inline bool THPUtils_checkScalar(PyObject* obj) { } #endif return PyFloat_Check(obj) || PyLong_Check(obj) || PyComplex_Check(obj) || - torch::is_symint_node(py::handle(obj)) || - torch::is_symfloat_node(py::handle(obj)); + torch::is_symint(py::handle(obj)) || torch::is_symfloat(py::handle(obj)); } namespace torch { @@ -574,7 +568,7 @@ inline std::vector PythonArgs::intlist(int i) { inline PyObject* toPyObject(c10::SymInt symint) { if (symint.is_symbolic()) { - auto r = py::cast(symint.toSymIntNodeImpl()).release().ptr(); + auto r = py::cast(symint).release().ptr(); TORCH_INTERNAL_ASSERT(r); return r; } else { @@ -609,8 +603,8 @@ inline std::vector PythonArgs::symintlist(int i) { size1, c10::SymInt(THPUtils_unpackIndex(args[i]))); } - if (size1 > 0 && torch::is_symint_node(py::handle(args[i]))) { - auto si = py::handle(args[i]).cast()->toSymInt(); + if (size1 > 0 && torch::is_symint(py::handle(args[i]))) { + auto si = py::handle(args[i]).cast(); return std::vector(size1, si); } @@ -652,9 +646,8 @@ inline std::vector PythonArgs::symintlist(int i) { res.push_back(var.item()); } else { try { - if (is_symint_node(py::handle(obj))) { - res.push_back( - py::handle(obj).cast()->toSymInt()); + if (is_symint(py::handle(obj))) { + res.push_back(py::handle(obj).cast()); } else { res.push_back(c10::SymInt(THPUtils_unpackIndex(obj))); } diff --git a/torch/csrc/utils/python_symnode.cpp b/torch/csrc/utils/python_symnode.cpp new file mode 100644 index 000000000000..318bb2266aa4 --- /dev/null +++ b/torch/csrc/utils/python_symnode.cpp @@ -0,0 +1,19 @@ +#include + +namespace torch { + +py::handle get_symint_class() { + // NB: leak + static py::handle symint_class = + py::object(py::module::import("torch").attr("SymInt")).release(); + return symint_class; +} + +py::handle get_symfloat_class() { + // NB: leak + static py::handle symfloat_class = + py::object(py::module::import("torch").attr("SymFloat")).release(); + return symfloat_class; +} + +} // namespace torch diff --git a/torch/csrc/utils/python_symnode.h b/torch/csrc/utils/python_symnode.h new file mode 100644 index 000000000000..be402e4d5439 --- /dev/null +++ b/torch/csrc/utils/python_symnode.h @@ -0,0 +1,182 @@ +#pragma once + +#include +#include + +#include +#include + +namespace torch { + +TORCH_PYTHON_API py::handle get_symint_class(); +TORCH_PYTHON_API py::handle get_symfloat_class(); + +// NB: These functions must not be called too early, otherwise torch not setup. +// Alternate design is to have torch "register" the object to us +inline bool is_symint(py::handle obj) { + return py::isinstance(obj, get_symint_class()); +} +inline bool is_symfloat(py::handle obj) { + return py::isinstance(obj, get_symfloat_class()); +} + +namespace impl { + +// This c10::SymNodeImpl simply backends to a Python object that +// implements the API. The Python object is the source of truth, +// this is just an adapter so C++ calls can get to the object. +class PythonSymNodeImpl : public c10::SymNodeImpl { + public: + PythonSymNodeImpl(py::object pyobj) : c10::SymNodeImpl() { + pyobj_ = std::make_shared( + pyobj.release().ptr(), getPyInterpreter()); + }; + + c10::SymNode wrap_int(int64_t num) override { + py::gil_scoped_acquire acquire; + auto r = getPyObj().attr("wrap_int")(num); + return c10::make_intrusive(r); + } + + c10::SymNode wrap_float(double num) override { + py::gil_scoped_acquire acquire; + auto r = getPyObj().attr("wrap_float")(num); + return c10::make_intrusive(r); + } + + bool bool_() override { + py::gil_scoped_acquire acquire; + return getPyObj().attr("bool_")().is(py::handle(Py_True)); + } + + bool is_int() override { + py::gil_scoped_acquire acquire; + return getPyObj().attr("is_int")().is(py::handle(Py_True)); + } + + bool is_float() override { + py::gil_scoped_acquire acquire; + return getPyObj().attr("is_float")().is(py::handle(Py_True)); + } + + int64_t guard_int(const char* file, int64_t line) override { + py::gil_scoped_acquire acquire; + return getPyObj().attr("guard_int")(file, line).cast(); + } + + double guard_float(const char* file, int64_t line) override { + py::gil_scoped_acquire acquire; + return getPyObj().attr("guard_float")(file, line).cast(); + } + + int64_t int_() override { + py::gil_scoped_acquire acquire; + return getPyObj().attr("int_")().cast(); + } + + std::string str() override { + py::gil_scoped_acquire acquire; + return getPyObj().attr("str")().cast(); + } + + c10::SymNode dispatch_common_(const char* fname, const c10::SymNode& other) { + auto pother = dynamic_cast(other.get()); + TORCH_CHECK(pother); + py::gil_scoped_acquire acquire; + auto r = getPyObj().attr(fname)(pother->getPyObj()); + return c10::make_intrusive(r); + } + + c10::SymNode dispatch_common_(const char* fname) { + py::gil_scoped_acquire acquire; + auto r = getPyObj().attr(fname)(); + return c10::make_intrusive(r); + } + + c10::SymNode add(const c10::SymNode& other) override { + return dispatch_common_(__FUNCTION__, other); + } + + c10::SymNode sub(const c10::SymNode& other) override { + return dispatch_common_(__FUNCTION__, other); + } + + c10::SymNode mul(const c10::SymNode& other) override { + return dispatch_common_(__FUNCTION__, other); + } + + c10::SymNode truediv(const c10::SymNode& other) override { + return dispatch_common_(__FUNCTION__, other); + } + + c10::SymNode pow(const c10::SymNode& other) override { + return dispatch_common_(__FUNCTION__, other); + } + + c10::SymNode floordiv(const c10::SymNode& other) override { + return dispatch_common_(__FUNCTION__, other); + } + + c10::SymNode mod(const c10::SymNode& other) override { + return dispatch_common_(__FUNCTION__, other); + } + + c10::SymNode eq(const c10::SymNode& other) override { + return dispatch_common_(__FUNCTION__, other); + } + + c10::SymNode gt(const c10::SymNode& other) override { + return dispatch_common_(__FUNCTION__, other); + } + + c10::SymNode lt(const c10::SymNode& other) override { + return dispatch_common_(__FUNCTION__, other); + } + + c10::SymNode le(const c10::SymNode& other) override { + return dispatch_common_(__FUNCTION__, other); + } + + c10::SymNode ge(const c10::SymNode& other) override { + return dispatch_common_(__FUNCTION__, other); + } + + c10::SymNode min(const c10::SymNode& other) override { + return dispatch_common_(__FUNCTION__, other); + } + c10::SymNode max(const c10::SymNode& other) override { + return dispatch_common_(__FUNCTION__, other); + } + + c10::SymNode ceil() override { + return dispatch_common_(__FUNCTION__); + } + + c10::SymNode floor() override { + return dispatch_common_(__FUNCTION__); + } + + c10::SymNode neg() override { + return dispatch_common_(__FUNCTION__); + } + + c10::SymNode clone() override { + return dispatch_common_(__FUNCTION__); + } + + c10::SymNode sym_int() override { + return dispatch_common_(__FUNCTION__); + } + + c10::SymNode sym_float() override { + return dispatch_common_(__FUNCTION__); + } + + py::handle getPyObj() { + return py::handle(pyobj_.get()->ptr(getPyInterpreter())); + } + std::shared_ptr pyobj_ = nullptr; +}; + +} // namespace impl +} // namespace torch diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 86d1e1955092..c83560754890 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -21,8 +21,9 @@ import operator from torch.utils._python_dispatch import TorchDispatchMode, _pop_mode_temporarily, _get_current_dispatch_mode from torch._subclasses import FakeTensor -from .symbolic_shapes import ShapeEnv, SymDispatchMode, PySymInt, PySymFloat +from .symbolic_shapes import ShapeEnv, SymDispatchMode, SymNode from torch.fx import Proxy +from torch import SymInt, SymFloat __all__ = ["PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter", "get_proxy", "has_proxy", "py_sym_types"] aten = torch.ops.aten @@ -55,27 +56,27 @@ def decompose(decomposition_table): proxy_slot = object() no_default = object() -py_sym_types = ( - PySymInt, - PySymFloat, -) +py_sym_types = (SymInt, SymFloat) def is_sym_node(node): assert hasattr(node, 'meta'), "All nodes traced with proxy_tensor should have meta" return "val" in node.meta and isinstance(node.meta['val'], py_sym_types) def set_proxy_slot(obj, tracer, proxy): - d = obj.__dict__.setdefault(proxy_slot, weakref.WeakKeyDictionary()) + assert isinstance(obj, (torch.Tensor, SymNode)), type(obj) + d = obj.__dict__.setdefault(proxy_slot, weakref.WeakKeyDictionary()) # type: ignore[call-overload] assert isinstance(d, weakref.WeakKeyDictionary) d[tracer] = proxy def has_proxy_slot(obj, tracer): + assert isinstance(obj, (torch.Tensor, SymNode)), type(obj) return get_proxy_slot(obj, tracer, False, lambda _: True) # the default argument is what to return if the slot is not set. # the transform argument is handy if you need to extract a subfield from # the successfully looked up result (but NOT the default.) def get_proxy_slot(obj, tracer, default=no_default, transform=lambda x: x): - d = obj.__dict__.get(proxy_slot) + assert isinstance(obj, (torch.Tensor, SymNode)), type(obj) + d = obj.__dict__.get(proxy_slot) # type: ignore[call-overload] if not d: if default is no_default: raise KeyError(f"{obj} is not tracked with proxy for {tracer}") @@ -130,10 +131,8 @@ def track_tensor(tensor, proxy, *, constant, tracer): def try_set_proxy_slot(outer_s, proxy_callable, *args): assert callable(proxy_callable) if isinstance(outer_s, SymInt): - inner_s = outer_s.get_pyobj() - assert isinstance(inner_s, py_sym_types) - - set_proxy_slot(inner_s, tracer, thunkify(proxy_callable, inner_s, *args)) + inner_s = outer_s.node + set_proxy_slot(inner_s, tracer, thunkify(proxy_callable, outer_s, *args)) # The basic idea is that we need to associate each tensor/SymInt # with a Proxy. How do we setup this association? We just store @@ -198,7 +197,7 @@ class _ProxyTensor: def fetch_sym_proxy(tracer): def inner(e): - n = e.get_pyobj() + n = e.node if n.constant is not None: return n.constant else: @@ -400,8 +399,8 @@ class PythonKeyTracer(Tracer): return self.create_node('get_attr', qualname, (), {}) elif isinstance(a, (SymInt, SymFloat)): - assert a.get_pyobj().constant is not None - return a.get_pyobj().constant + assert a.node.constant is not None + return a.node.constant return super().create_arg(a) @@ -432,7 +431,7 @@ def wrap_key(f, tensors, tracer): ) out = pytree.tree_map_only( (SymInt, SymFloat), - lambda t: get_proxy_slot(t.get_pyobj(), tracer)(), + lambda t: get_proxy_slot(t.node, tracer)(), out ) return out @@ -479,10 +478,6 @@ class ProxyTorchDispatchMode(TorchDispatchMode): return out -SymInt = torch.SymIntNode -SymFloat = torch.SymFloatNode - - class ProxySymDispatchMode(SymDispatchMode): def __init__(self, tracer): super().__init__() @@ -501,10 +496,9 @@ class ProxySymDispatchMode(SymDispatchMode): finally: self.enable_tracing = old - def _compute_proxy(self, func, args, out): + def _compute_proxy(self, func, args, out: Union[SymInt, SymFloat]): n_args = tuple( - get_proxy_slot(a, self.tracer)().node if a.constant is None else a.constant - if isinstance(a, py_sym_types) else a + get_proxy_slot(a.node, self.tracer)().node if isinstance(a, py_sym_types) else a for a in args ) @@ -520,10 +514,11 @@ class ProxySymDispatchMode(SymDispatchMode): return func(*args, **kwargs) # Peephole optimize multiply by one + # NB: be careful not to trigger guards here! if func == operator.mul: - if isinstance(args[1], (PySymInt, PySymFloat)) and args[1].constant == 1: + if isinstance(args[1], int) and args[1] == 1: return args[0] - elif isinstance(args[0], (PySymInt, PySymFloat)) and args[0].constant == 1: + elif isinstance(args[0], int) and args[0] == 1: return args[1] # For speed, we assume there are no nested data structures @@ -535,7 +530,7 @@ class ProxySymDispatchMode(SymDispatchMode): # Delays tracing out the proxies on this op until we actually need it p_out_thunk = thunkify(self._compute_proxy, func=func, args=args, out=out) - set_proxy_slot(out, self.tracer, p_out_thunk) + set_proxy_slot(out.node, self.tracer, p_out_thunk) return out diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 0a03e5819a90..2eb169a0d188 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -10,6 +10,7 @@ import traceback import collections import textwrap from torch._subclasses.meta_utils import MetaConverter +from torch import SymInt, SymFloat try: import sympy # type: ignore[import] @@ -21,8 +22,8 @@ except ImportError: aten = torch.ops.aten # type: ignore[has-type] __all__ = [ - "has_symbolic_sizes_strides", "create_contiguous", "PySymInt", "ShapeEnv", - "SymDispatchMode", "PySymFloat", "sym_float", "FloorDiv" + "has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv", + "SymDispatchMode", "sym_float", "FloorDiv", "guard_int", "wrap_node" ] SYM_FUNCTION_MODE = None @@ -88,32 +89,38 @@ def _handle_sym_dispatch(func, args, kwargs): finally: SYM_FUNCTION_MODE = mode +def guard_int(a): + if isinstance(a, SymInt): + return a.node.guard_int("", 0) # NB: uses Python backtrace + assert isinstance(a, int) + return a + def sym_float(a): - if hasattr(a, '__sym_float__'): - return a.__sym_float__() - elif isinstance(a, torch._C.SymFloatNode): + if isinstance(a, SymFloat): return a + elif hasattr(a, '__sym_float__'): + return a.__sym_float__() return float(a) def sym_int(a): - if hasattr(a, '__sym_int__'): - return a.__sym_int__() - elif isinstance(a, torch._C.SymIntNode): + if isinstance(a, SymInt): return a + elif hasattr(a, '__sym_int__'): + return a.__sym_int__() return int(a) # TODO: An incomplete list # 1. Set variables to be equal when we do equality # 2. Specialize on 0/1 when we do subtraction -class PySymInt(object): +class SymNode: """ - PySymInt objects are the primary "symbolic shape" objects that flow through - our program. They're what sit under FakeTensor, and contains our primary - implementation of symbolic shapes. + This is a type erased SymInt/SymFloat which we use to do actual operations. + End users don't touch this. Magic methods are NOT defined on this object. """ - def __init__(self, expr, shape_env, constant=None): + def __init__(self, expr, shape_env, pytype, constant=None): self._expr = expr self.shape_env = shape_env + self.pytype = pytype self.constant = constant @property @@ -121,23 +128,49 @@ class PySymInt(object): self._update_expr() return self._expr - def wrap(self, num): - return PySymInt(sympy.Integer(num), self.shape_env, constant=num) - - def clone(self): - return PySymInt(self.expr, self.shape_env, constant=self.constant) - def _update_expr(self): self._expr = self.shape_env.replace(self._expr) - def __str__(self): + def to_node(self, num): + if isinstance(num, (SymInt, SymFloat)): + return num.node + elif isinstance(num, int): + return self.wrap_int(num) + elif isinstance(num, float): + return self.wrap_float(num) + else: + # NotImplementedError is important so that Python tries the + # other magic method + raise NotImplementedError(type(num)) + + def is_int(self): + return self.pytype is int + + def is_float(self): + return self.pytype is float + + def wrap_int(self, num): + assert isinstance(num, int) + return SymNode(sympy.Integer(num), self.shape_env, int, constant=num) + + def wrap_float(self, num): + assert isinstance(num, float) + return SymNode(sympy.Integer(num), self.shape_env, float, constant=num) + + def clone(self): + return SymNode(self.expr, self.shape_env, self.pytype, constant=self.constant) + + def str(self): return f"{self.expr}" + def __str__(self): + return self.str() + def __repr__(self): - return f"{self.expr}" + return self.str() # Today we error on calling int on a symbolic shape, as this is a very accessible footgun. - def __int__(self): + def int_(self): raise RuntimeError("Trying to extract a concrete int out of a symbolic int") # You can manually trigger a guard with this function @@ -146,28 +179,35 @@ class PySymInt(object): # guard occurred return int(self.shape_env.evaluate_expr(self.expr)) - def __sym_float__(self): + def guard_float(self, file, line): + # TODO: use the file/line for some useful diagnostic on why a + # guard occurred + return float(self.shape_env.evaluate_expr(self.expr)) + + def sym_float(self): if SYM_FUNCTION_MODE: - return _handle_sym_dispatch(sym_float, (self,), {}) + r = _handle_sym_dispatch(sym_float, (wrap_node(self),), {}) + assert isinstance(r, (SymInt, SymFloat)), type(r) + return r.node # TODO: consider constant prop here # TODO: wrapping the expr with sympy.Float doesn't seem to work, why # not? - return PySymFloat(self.expr, self.shape_env) + return SymNode(self.expr, self.shape_env, float) - def __bool__(self): + def sym_int(self): + raise NotImplementedError("sym_int NYI") + """ + if SYM_FUNCTION_MODE: + return _handle_sym_dispatch(sym_int, (self,), {}) + # TODO: consider constant prop here + # XXX: need to cast float to int in sympy; math.floor is wrong + # because negatives round to zero + return SymNode(self.expr, self.shape_env, int) + """ + + def bool_(self): return bool(self.shape_env.evaluate_expr(self.shape_env.replace(self.expr))) -class PySymFloat: - def __init__(self, expr, shape_env, constant=None): - self.expr = expr - self.shape_env = shape_env - self.constant = constant - - def wrap(self, num): - return PySymFloat(sympy.Float(num), self.shape_env, constant=num) - - def __str__(self): - return f"{self.expr}" if HAS_SYMPY: class FloorDiv(sympy.Function): @@ -238,32 +278,45 @@ unary_magic_methods = { float_magic_methods = {"add", "sub", "mul", "truediv", "ceil", "floor", "eq", "gt", "lt", "le", "ge", "pow"} -def _make_magic(method, func, py_type): +def wrap_node(x): + if not isinstance(x, SymNode): + return x + if x.constant is not None: + return x.constant + if x.pytype is int: + return SymInt(x) + elif x.pytype is float: + return SymFloat(x) + else: + raise AssertionError(f"unrecognized return type {x.pytype}") + +def _make_node_magic(method, func): func = lru_cache(256)(func) - def magic_impl(self, other): + def binary_magic_impl(self, other): if method in ["min", "max"]: op = getattr(builtins, method) else: op = getattr(operator, method) if SYM_FUNCTION_MODE: - return _handle_sym_dispatch(op, (self, other), {}) - if isinstance(other, py_type): - other_expr = other.expr - else: - assert isinstance(other, sympy.Expr) - other_expr = other + r = _handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {}) + assert isinstance(r, (SymInt, SymFloat)), type(r) + return r.node + assert isinstance(other, SymNode) + other_expr = other.expr # TODO: consider constant prop here expr = self.shape_env.replace(self.expr) other_expr = self.shape_env.replace(other_expr) out = func(expr, other_expr) out = sympy.expand(out) if method in ["truediv"]: - return PySymFloat(out, self.shape_env) + pytype = float else: - # TODO: relational operators actually technically return a - # PySymBool, this is a type error - return py_type(out, self.shape_env) + pytype = self.pytype + + # TODO: relational operators actually technically return a + # PySymBool, this is a type error + return SymNode(out, self.shape_env, pytype) def unary_magic_impl(self): if SYM_FUNCTION_MODE: @@ -271,33 +324,55 @@ def _make_magic(method, func, py_type): op = getattr(math, method) else: op = getattr(operator, method) - return _handle_sym_dispatch(op, (self,), {}) + r = _handle_sym_dispatch(op, (wrap_node(self),), {}) + assert isinstance(r, (SymInt, SymFloat)), type(r) + return r.node # TODO: consider constant prop here expr = self.shape_env.replace(self.expr) out = func(expr) out = sympy.expand(out) if method in ["ceil", "floor"]: - return PySymInt(out, self.shape_env) + pytype = int else: - return py_type(out, self.shape_env) + pytype = self.pytype + + return SymNode(out, self.shape_env, pytype) - # this should be wrapped transparently into torch.SymIntNode if method in unary_magic_methods: - setattr(py_type, method, unary_magic_impl) - setattr(py_type, f"__{method}__", unary_magic_impl) + setattr(SymNode, method, unary_magic_impl) else: - setattr(py_type, method, magic_impl) - setattr(py_type, f"__{method}__", magic_impl) - if method in reflectable_magic_methods: - setattr(py_type, f"__r{method}__", magic_impl) + setattr(SymNode, method, binary_magic_impl) for method, func in magic_methods.items(): - _make_magic(method, func, PySymInt) + _make_node_magic(method, func) + +def _make_user_magic(method, user_type): + # User magic takes care of wrapping the other operand into a node, + # so that our internal logic can assume everything is nodes + + def unary_magic_impl(self): + return wrap_node(getattr(self.node, method)()) + + def binary_magic_impl(self, other): + return wrap_node(getattr(self.node, method)(self.node.to_node(other))) + + def rbinary_magic_impl(self, other): + return wrap_node(getattr(self.node.to_node(other), method)(self.node)) + + if method in unary_magic_methods: + setattr(user_type, f"__{method}__", unary_magic_impl) + else: + setattr(user_type, f"__{method}__", binary_magic_impl) + if method in reflectable_magic_methods: + setattr(user_type, f"__r{method}__", rbinary_magic_impl) + +for method, func in magic_methods.items(): + _make_user_magic(method, SymInt) for method, func in magic_methods.items(): if method not in float_magic_methods: continue - _make_magic(method, func, PySymFloat) + _make_user_magic(method, SymFloat) del method del func @@ -390,9 +465,7 @@ class ShapeEnv(object): return [self.create_symintnode(i) for i in size], [self.create_symintnode(i) for i in stride] # type: ignore[arg-type] def create_symintnode(self, expr: "sympy.Expr"): - py_sym_int = PySymInt(expr, self) - cpp_sym_int = torch.SymIntNode.new_symint(py_sym_int) # type: ignore[attr-defined] - return cpp_sym_int + return SymInt(SymNode(expr, self, int)) def create_symbol(self, val: int) -> "sympy.Expr": if not HAS_SYMPY: diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 3b8c96b6a43b..4fdd64f900a9 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -498,7 +498,7 @@ class CodeGen(object): if isinstance(meta_val, FakeTensor): maybe_type_annotation = f': {dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}' elif isinstance(meta_val, py_sym_types): - maybe_type_annotation = f': Sym({meta_val.expr})' + maybe_type_annotation = f': Sym({meta_val})' elif isinstance(meta_val, TensorMetadata): maybe_type_annotation = f': {dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}'