mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Unify SymIntNode and SymFloatNode into SymNode (#87817)
This refactor was prompted by challenges handling mixed int/float operations in C++. A previous version of this patch added overloads for each permutation of int/float and was unwieldy https://github.com/pytorch/pytorch/pull/87722/ This PR takes a different approach. The general outline of the patch is to combine the C++ types SymIntNode and SymFloatNode into a single type, SymNode. This is type erased; we no longer know statically at C++ if we have an int/float and have to test it with the is_int()/is_float() virtual methods. This has a number of knock on effects. - We no longer have C++ classes to bind to Python. Instead, we take an entirely new approach to our Python API, where we have a SymInt/SymFloat class defined entirely in Python, which hold a SymNode (which corresponds to the C++ SymNode). However, SymNode is not pybind11-bound; instead, it lives as-is in Python, and is wrapped into C++ SymNode using PythonSymNode when it goes into C++. This implies a userland rename. In principle, it is also possible for the canonical implementation of SymNode to be written in C++, and then bound to Python with pybind11 (we have this code, although it is commented out.) However, I did not implement this as we currently have no C++ implementations of SymNode. Because we do return SymInt/SymFloat from C++ bindings, the C++ binding code needs to know how to find these classes. Currently, this is done just by manually importing torch and getting the attributes. - Because SymInt/SymFloat are easy Python wrappers, __sym_dispatch__ now takes SymInt/SymFloat, rather than SymNode, bringing it in line with how __torch_dispatch__ works. Some miscellaneous improvements: - SymInt now has a constructor that takes SymNode. Note that this constructor is ambiguous if you pass in a subclass of SymNode, so an explicit downcast is necessary. This means toSymFloat/toSymInt are no more. This is a mild optimization as it means rvalue reference works automatically. - We uniformly use the caster for c10::SymInt/SymFloat, rather than going the long way via the SymIntNode/SymFloatNode. - Removed some unnecessary toSymInt/toSymFloat calls in normalize_* functions, pretty sure this doesn't do anything. - guard_int is now a free function, since to guard on an int you cannot assume the method exists. A function can handle both int and SymInt inputs. - We clean up the magic method definition code for SymInt/SymFloat/SymNode. ONLY the user classes (SymInt/SymFloat) get magic methods; SymNode gets plain methods; this is to help avoid confusion between the two types. Signed-off-by: Edward Z. Yang <ezyang@fb.com> cc @jansel @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 Pull Request resolved: https://github.com/pytorch/pytorch/pull/87817 Approved by: https://github.com/albanD, https://github.com/anjali411
This commit is contained in:
committed by
PyTorch MergeBot
parent
2205f56f46
commit
1ff52225f1
2
.github/ci_commit_pins/xla.txt
vendored
2
.github/ci_commit_pins/xla.txt
vendored
@ -1 +1 @@
|
||||
095ee628212f0235ad0d6908bdd514123639fc86
|
||||
1e9b8bdc75114ac6c16305c970be37a1cd2fdb1c
|
||||
|
@ -439,7 +439,7 @@ command = [
|
||||
"""--error-description=\
|
||||
This line has an isinstance call that directly refers to \
|
||||
int or float. This is error-prone because you may also \
|
||||
have wanted to allow SymIntNode or SymFloatNode in your test. \
|
||||
have wanted to allow SymInt or SymFloat in your test. \
|
||||
To suppress this lint, use an appropriate type alias defined \
|
||||
in torch._prims_common; use IntLike/FloatLike when you would accept \
|
||||
both regular and symbolic numbers, Dim for ints representing \
|
||||
|
@ -95,7 +95,7 @@ c10::SymInt get_nbytes(const Tensor& value) {
|
||||
if (value.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {
|
||||
// Today, the two implementations of SymInt are in Python (proxy tensor),
|
||||
// and lazy tensor (LTC/XLA).
|
||||
// LTC hasn't implemented SymInt support yet though (torch::lazy::SymIntNodeImpl).
|
||||
// LTC hasn't implemented SymInt support yet though
|
||||
// Once it does, we should remove this check.
|
||||
if (value.key_set().has(c10::DispatchKey::Python)) {
|
||||
return value.storage().sym_nbytes();
|
||||
|
@ -562,7 +562,7 @@ public:
|
||||
IValue(c10::SymInt i) {
|
||||
if (i.is_symbolic()) {
|
||||
tag = Tag::SymInt;
|
||||
payload.u.as_intrusive_ptr = i.toSymIntNodeImpl().release();
|
||||
payload.u.as_intrusive_ptr = i.toSymNodeImpl().release();
|
||||
} else {
|
||||
tag = Tag::Int;
|
||||
payload.u.as_int = i.as_int_unchecked();
|
||||
@ -578,7 +578,7 @@ public:
|
||||
IValue(c10::SymFloat i) {
|
||||
if (i.is_symbolic()) {
|
||||
tag = Tag::SymFloat;
|
||||
payload.u.as_intrusive_ptr = i.toSymFloatNodeImpl().release();
|
||||
payload.u.as_intrusive_ptr = i.toSymNodeImpl().release();
|
||||
} else {
|
||||
tag = Tag::Double;
|
||||
payload.u.as_double = i.as_float_unchecked();
|
||||
@ -812,10 +812,10 @@ public:
|
||||
// for both SymFloat and double
|
||||
if (s.isSymInt()) {
|
||||
tag = Tag::SymInt;
|
||||
payload.u.as_intrusive_ptr = s.toSymInt().toSymIntNodeImpl().release();
|
||||
payload.u.as_intrusive_ptr = s.toSymInt().toSymNodeImpl().release();
|
||||
} else if (s.isSymFloat()) {
|
||||
tag = Tag::SymFloat;
|
||||
payload.u.as_intrusive_ptr = s.toSymFloat().toSymFloatNodeImpl().release();
|
||||
payload.u.as_intrusive_ptr = s.toSymFloat().toSymNodeImpl().release();
|
||||
} else if (s.isFloatingPoint()) {
|
||||
tag = Tag::Double;
|
||||
payload.u.as_double = s.toDouble();
|
||||
|
@ -219,7 +219,7 @@ inline at::Generator IValue::toGenerator() const& {
|
||||
inline c10::SymInt IValue::toSymInt() const {
|
||||
AT_ASSERT(isSymInt() || isInt(), "Expected SymInt or int but got ", tagKind());
|
||||
if (isSymInt()) {
|
||||
return c10::SymInt::toSymInt(toIntrusivePtr<c10::SymIntNodeImpl>());
|
||||
return c10::SymInt(toIntrusivePtr<c10::SymNodeImpl>());
|
||||
} else {
|
||||
return c10::SymInt(payload.u.as_int);
|
||||
}
|
||||
@ -228,7 +228,7 @@ inline c10::SymInt IValue::toSymInt() const {
|
||||
inline c10::SymFloat IValue::toSymFloat() const {
|
||||
AT_ASSERT(isSymFloat() || isDouble(), "Expected SymFloat or double but got ", tagKind());
|
||||
if (isSymFloat()) {
|
||||
return c10::SymFloat::toSymFloat(toIntrusivePtr<c10::SymFloatNodeImpl>());
|
||||
return c10::SymFloat(toIntrusivePtr<c10::SymNodeImpl>());
|
||||
} else {
|
||||
return c10::SymFloat(payload.u.as_double);
|
||||
}
|
||||
|
@ -1310,7 +1310,6 @@ struct TORCH_API SymIntType : public Type {
|
||||
return "SymInt";
|
||||
}
|
||||
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
|
||||
// TODO: will become a Union[SymIntNodeImpl|int] in the near future
|
||||
return "int";
|
||||
}
|
||||
static const TypeKind Kind = TypeKind::SymIntType;
|
||||
|
@ -194,34 +194,3 @@ TEST(TestScalar, TestFormatting) {
|
||||
ASSERT_EQ("(2,3.1)", format(Scalar(c10::complex<float>(2.0, 3.1))));
|
||||
ASSERT_EQ("4", format(Scalar(Scalar(4).toSymInt())));
|
||||
}
|
||||
|
||||
TEST(TestSymInt, Basic) {
|
||||
Scalar foo;
|
||||
auto a_impl = c10::make_intrusive<c10::SymIntNodeImpl>();
|
||||
foo = Scalar(a_impl->toSymInt());
|
||||
ASSERT_EQ(a_impl.use_count(), 2);
|
||||
Scalar bar{foo};
|
||||
ASSERT_EQ(a_impl.use_count(), 3);
|
||||
auto baz = bar;
|
||||
ASSERT_EQ(a_impl.use_count(), 4);
|
||||
auto foo2 = std::move(bar);
|
||||
ASSERT_EQ(a_impl.use_count(), 4);
|
||||
ASSERT_TRUE(foo2.isSymInt());
|
||||
// NOLINTNEXTLINE(bugprone-use-after-move,clang-analyzer-cplusplus.Move)
|
||||
ASSERT_TRUE(bar.isIntegral(false));
|
||||
foo2 = SymInt(4);
|
||||
ASSERT_FALSE(foo2.isSymInt());
|
||||
ASSERT_EQ(foo2.toSymInt().expect_int(), 4);
|
||||
// NOLINTNEXTLINE(clang-diagnostic-self-assign-overloaded)
|
||||
foo2 = foo2;
|
||||
ASSERT_FALSE(foo2.isSymInt());
|
||||
ASSERT_EQ(foo2.toSymInt().expect_int(), 4);
|
||||
|
||||
ASSERT_EQ(a_impl.use_count(), 3);
|
||||
|
||||
ASSERT_THROW(foo.to<double>(), c10::Error);
|
||||
|
||||
Scalar int_s = 3;
|
||||
TORCH_CHECK(int_s.toSymInt().expect_int(), 3);
|
||||
|
||||
}
|
||||
|
@ -958,6 +958,7 @@ libtorch_python_core_sources = [
|
||||
"torch/csrc/utils/object_ptr.cpp",
|
||||
"torch/csrc/utils/python_arg_parser.cpp",
|
||||
"torch/csrc/utils/python_dispatch.cpp",
|
||||
"torch/csrc/utils/python_symnode.cpp",
|
||||
"torch/csrc/utils/structseq.cpp",
|
||||
"torch/csrc/utils/tensor_apply.cpp",
|
||||
"torch/csrc/utils/tensor_dtypes.cpp",
|
||||
|
@ -92,8 +92,8 @@ class C10_API Scalar {
|
||||
|
||||
SymInt toSymInt() const {
|
||||
if (Tag::HAS_si == tag) {
|
||||
return c10::SymInt::toSymInt(intrusive_ptr<SymIntNodeImpl>::reclaim_copy(
|
||||
static_cast<SymIntNodeImpl*>(v.p)));
|
||||
return c10::SymInt(intrusive_ptr<SymNodeImpl>::reclaim_copy(
|
||||
static_cast<SymNodeImpl*>(v.p)));
|
||||
} else {
|
||||
return toLong();
|
||||
}
|
||||
@ -101,9 +101,8 @@ class C10_API Scalar {
|
||||
|
||||
SymFloat toSymFloat() const {
|
||||
if (Tag::HAS_sd == tag) {
|
||||
return c10::SymFloat::toSymFloat(
|
||||
intrusive_ptr<SymFloatNodeImpl>::reclaim_copy(
|
||||
static_cast<SymFloatNodeImpl*>(v.p)));
|
||||
return c10::SymFloat(intrusive_ptr<SymNodeImpl>::reclaim_copy(
|
||||
static_cast<SymNodeImpl*>(v.p)));
|
||||
} else {
|
||||
return toDouble();
|
||||
}
|
||||
|
@ -1,32 +1,27 @@
|
||||
#include <c10/core/SymFloat.h>
|
||||
#include <c10/core/SymFloatNodeImpl.h>
|
||||
#include <c10/core/SymNodeImpl.h>
|
||||
#include <array>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
SymFloatNode SymFloat::toSymFloatNodeImpl() const {
|
||||
SymNode SymFloat::toSymNodeImpl() const {
|
||||
TORCH_CHECK(is_symbolic());
|
||||
return SymFloatNode::reclaim_copy(toSymFloatNodeImplUnowned());
|
||||
return SymNode::reclaim_copy(toSymNodeImplUnowned());
|
||||
}
|
||||
|
||||
static std::array<SymFloatNode, 2> normalize_symfloats(
|
||||
SymFloat a_,
|
||||
SymFloat b_) {
|
||||
SymFloatNode a, b;
|
||||
static std::array<SymNode, 2> normalize_symfloats(SymFloat a_, SymFloat b_) {
|
||||
SymNode a, b;
|
||||
if (a_.is_symbolic())
|
||||
a = a_.toSymFloatNodeImpl();
|
||||
a = a_.toSymNodeImpl();
|
||||
if (b_.is_symbolic())
|
||||
b = b_.toSymFloatNodeImpl();
|
||||
b = b_.toSymNodeImpl();
|
||||
|
||||
SymFloatNodeImpl* common = a ? a.get() : b.get();
|
||||
// TODO: technically we need to check that the classes match
|
||||
SymNodeImpl* common = a ? a.get() : b.get();
|
||||
if (!a) {
|
||||
a = common->wrap(a_.as_float_unchecked());
|
||||
a_.toSymFloat(a); //
|
||||
a = common->wrap_float(a_.as_float_unchecked());
|
||||
}
|
||||
if (!b) {
|
||||
b = common->wrap(b_.as_float_unchecked());
|
||||
b_.toSymFloat(b);
|
||||
b = common->wrap_float(b_.as_float_unchecked());
|
||||
}
|
||||
return {a, b};
|
||||
}
|
||||
@ -36,7 +31,7 @@ SymFloat SymFloat::operator+(SymFloat sci) const {
|
||||
return SymFloat(data_ + sci.data_);
|
||||
}
|
||||
auto res = normalize_symfloats(*this, sci);
|
||||
return SymFloat::toSymFloat(res[0]->add(res[1]));
|
||||
return SymFloat(res[0]->add(res[1]));
|
||||
}
|
||||
|
||||
SymFloat SymFloat::operator-(SymFloat sci) const {
|
||||
@ -44,7 +39,7 @@ SymFloat SymFloat::operator-(SymFloat sci) const {
|
||||
return SymFloat(data_ - sci.data_);
|
||||
}
|
||||
auto res = normalize_symfloats(*this, sci);
|
||||
return SymFloat::toSymFloat(res[0]->sub(res[1]));
|
||||
return SymFloat(res[0]->sub(res[1]));
|
||||
}
|
||||
|
||||
SymFloat SymFloat::operator*(SymFloat sci) const {
|
||||
@ -52,7 +47,7 @@ SymFloat SymFloat::operator*(SymFloat sci) const {
|
||||
return SymFloat(data_ * sci.data_);
|
||||
}
|
||||
auto res = normalize_symfloats(*this, sci);
|
||||
return SymFloat::toSymFloat(res[0]->mul(res[1]));
|
||||
return SymFloat(res[0]->mul(res[1]));
|
||||
}
|
||||
|
||||
SymFloat SymFloat::operator/(SymFloat sci) const {
|
||||
@ -60,16 +55,12 @@ SymFloat SymFloat::operator/(SymFloat sci) const {
|
||||
return SymFloat(data_ / sci.data_);
|
||||
}
|
||||
auto res = normalize_symfloats(*this, sci);
|
||||
return SymFloat::toSymFloat(res[0]->truediv(res[1]));
|
||||
}
|
||||
|
||||
c10::SymFloat SymFloat::toSymFloat(SymFloatNode sin_sp) {
|
||||
return c10::SymFloat(std::move(sin_sp));
|
||||
return SymFloat(res[0]->truediv(res[1]));
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, SymFloat s) {
|
||||
if (s.is_symbolic()) {
|
||||
os << s.toSymFloatNodeImpl()->str();
|
||||
os << s.toSymNodeImpl()->str();
|
||||
} else {
|
||||
os << s.as_float_unchecked();
|
||||
}
|
||||
|
@ -1,6 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/SymFloatNodeImpl.h>
|
||||
#include <c10/core/SymNodeImpl.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/intrusive_ptr.h>
|
||||
@ -14,20 +14,21 @@ namespace c10 {
|
||||
class C10_API SymFloat {
|
||||
public:
|
||||
/*implicit*/ SymFloat(double d) : data_(d){};
|
||||
SymFloat(SymFloatNode ptr)
|
||||
: data_(std::numeric_limits<double>::quiet_NaN()), ptr_(std::move(ptr)){};
|
||||
SymFloat(SymNode ptr)
|
||||
: data_(std::numeric_limits<double>::quiet_NaN()), ptr_(std::move(ptr)) {
|
||||
TORCH_CHECK(ptr_->is_float());
|
||||
};
|
||||
SymFloat() : data_(0.0) {}
|
||||
|
||||
SymFloatNodeImpl* toSymFloatNodeImplUnowned() const {
|
||||
SymNodeImpl* toSymNodeImplUnowned() const {
|
||||
return ptr_.get();
|
||||
}
|
||||
|
||||
SymFloatNodeImpl* release() && {
|
||||
SymNodeImpl* release() && {
|
||||
return std::move(ptr_).release();
|
||||
}
|
||||
|
||||
SymFloatNode toSymFloatNodeImpl() const;
|
||||
static c10::SymFloat toSymFloat(SymFloatNode sin);
|
||||
SymNode toSymNodeImpl() const;
|
||||
|
||||
double expect_float() const {
|
||||
TORCH_CHECK(!is_symbolic());
|
||||
@ -53,7 +54,7 @@ class C10_API SymFloat {
|
||||
private:
|
||||
// TODO: optimize to union
|
||||
double data_;
|
||||
SymFloatNode ptr_;
|
||||
SymNode ptr_;
|
||||
};
|
||||
|
||||
C10_API std::ostream& operator<<(std::ostream& os, SymFloat s);
|
||||
|
@ -1,20 +0,0 @@
|
||||
#include <c10/core/SymFloat.h>
|
||||
#include <c10/core/SymFloatNodeImpl.h>
|
||||
#include <c10/core/SymIntNodeImpl.h>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
c10::SymFloat SymFloatNodeImpl::toSymFloat() {
|
||||
auto sit_sp = SymFloatNode::reclaim_copy(this);
|
||||
return SymFloat::toSymFloat(sit_sp);
|
||||
}
|
||||
|
||||
c10::SymIntNode SymFloatNodeImpl::ceil() {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
}
|
||||
|
||||
c10::SymIntNode SymFloatNodeImpl::floor() {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
}
|
||||
|
||||
} // namespace c10
|
@ -1,76 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/intrusive_ptr.h>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <vector>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
class SymIntNodeImpl;
|
||||
using SymIntNode = c10::intrusive_ptr<SymIntNodeImpl>;
|
||||
|
||||
class SymFloat;
|
||||
class SymFloatNodeImpl;
|
||||
using SymFloatNode = c10::intrusive_ptr<SymFloatNodeImpl>;
|
||||
|
||||
class C10_API SymFloatNodeImpl : public c10::intrusive_ptr_target {
|
||||
public:
|
||||
c10::SymFloat toSymFloat();
|
||||
virtual ~SymFloatNodeImpl(){};
|
||||
|
||||
template <typename T>
|
||||
c10::intrusive_ptr<T> dyn_cast() const {
|
||||
return c10::intrusive_ptr<T>::reclaim_copy(dynamic_cast<T*>(this));
|
||||
}
|
||||
|
||||
virtual SymFloatNode wrap(double num) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual SymFloatNode add(const SymFloatNode& other) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
}
|
||||
virtual SymFloatNode sub(const SymFloatNode& other) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
}
|
||||
virtual SymFloatNode mul(const SymFloatNode& other) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
}
|
||||
virtual SymFloatNode truediv(const SymFloatNode& other) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
}
|
||||
virtual SymFloatNode pow(const SymFloatNode& other) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
}
|
||||
virtual SymFloatNode eq(const SymFloatNode& other) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual SymFloatNode ne(const SymFloatNode& other) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual SymFloatNode gt(const SymFloatNode& other) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual SymFloatNode lt(const SymFloatNode& other) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual SymFloatNode le(const SymFloatNode& other) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual SymFloatNode ge(const SymFloatNode& other) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual SymIntNode ceil();
|
||||
virtual SymIntNode floor();
|
||||
virtual std::string str() {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
std::ostream& operator<<(std::ostream& os) {
|
||||
os << str();
|
||||
return os;
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace c10
|
@ -1,47 +1,46 @@
|
||||
#include <c10/core/SymFloat.h>
|
||||
#include <c10/core/SymInt.h>
|
||||
#include <c10/core/SymIntNodeImpl.h>
|
||||
#include <c10/core/SymNodeImpl.h>
|
||||
#include <array>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
static std::array<SymIntNode, 2> normalize_symints(SymInt a_, SymInt b_) {
|
||||
SymIntNode a, b;
|
||||
static std::array<SymNode, 2> normalize_symints(SymInt a_, SymInt b_) {
|
||||
SymNode a, b;
|
||||
if (a_.is_symbolic())
|
||||
a = a_.toSymIntNodeImpl();
|
||||
a = a_.toSymNodeImpl();
|
||||
if (b_.is_symbolic())
|
||||
b = b_.toSymIntNodeImpl();
|
||||
b = b_.toSymNodeImpl();
|
||||
|
||||
SymIntNodeImpl* common = a ? a.get() : b.get();
|
||||
SymNodeImpl* common = a ? a.get() : b.get();
|
||||
// TODO: technically we need to check that the classes match
|
||||
if (!a) {
|
||||
a = common->wrap(a_.as_int_unchecked());
|
||||
a_.toSymInt(a); //
|
||||
a = common->wrap_int(a_.as_int_unchecked());
|
||||
}
|
||||
if (!b) {
|
||||
b = common->wrap(b_.as_int_unchecked());
|
||||
b_.toSymInt(b);
|
||||
b = common->wrap_int(b_.as_int_unchecked());
|
||||
}
|
||||
return {a, b};
|
||||
}
|
||||
|
||||
SymIntNode SymInt::toSymIntNodeImpl() const {
|
||||
SymNode SymInt::toSymNodeImpl() const {
|
||||
TORCH_CHECK(is_symbolic());
|
||||
return SymIntNode::reclaim_copy(toSymIntNodeImplUnowned());
|
||||
return SymNode::reclaim_copy(toSymNodeImplUnowned());
|
||||
}
|
||||
|
||||
c10::SymInt SymInt::toSymInt(SymIntNode sin_sp) {
|
||||
SymInt::SymInt(SymNode sin_sp) {
|
||||
TORCH_CHECK(sin_sp->is_int());
|
||||
auto ptr = static_cast<uint64_t>(
|
||||
reinterpret_cast<uintptr_t>(static_cast<void*>(sin_sp.release())));
|
||||
auto rep = (ptr & ~MASK) | IS_SYM;
|
||||
return c10::SymInt(UNCHECKED, static_cast<int64_t>(rep));
|
||||
data_ = static_cast<int64_t>(rep);
|
||||
}
|
||||
|
||||
int64_t SymInt::guard_int(const char* file, int64_t line) const {
|
||||
if (!is_symbolic()) {
|
||||
return data_;
|
||||
}
|
||||
SymIntNode a = toSymIntNodeImpl();
|
||||
SymNode a = toSymNodeImpl();
|
||||
return a->guard_int(file, line);
|
||||
}
|
||||
|
||||
@ -49,7 +48,7 @@ SymInt::operator SymFloat() const {
|
||||
if (!is_symbolic()) {
|
||||
return SymFloat(double(data_));
|
||||
}
|
||||
return SymFloat::toSymFloat(toSymIntNodeImpl()->sym_float());
|
||||
return SymFloat(toSymNodeImpl()->sym_float());
|
||||
}
|
||||
|
||||
SymInt SymInt::operator+(SymInt sci) const {
|
||||
@ -57,7 +56,7 @@ SymInt SymInt::operator+(SymInt sci) const {
|
||||
return SymInt(data_ + sci.data_);
|
||||
}
|
||||
auto res = normalize_symints(*this, sci);
|
||||
return SymInt::toSymInt(res[0]->add(res[1]));
|
||||
return SymInt(res[0]->add(res[1]));
|
||||
}
|
||||
|
||||
SymInt SymInt::operator-(SymInt sci) const {
|
||||
@ -65,7 +64,7 @@ SymInt SymInt::operator-(SymInt sci) const {
|
||||
return SymInt(data_ - sci.data_);
|
||||
}
|
||||
auto res = normalize_symints(*this, sci);
|
||||
return SymInt::toSymInt(res[0]->sub(res[1]));
|
||||
return SymInt(res[0]->sub(res[1]));
|
||||
}
|
||||
|
||||
SymInt SymInt::operator*(SymInt sci) const {
|
||||
@ -73,7 +72,7 @@ SymInt SymInt::operator*(SymInt sci) const {
|
||||
return SymInt(data_ * sci.data_);
|
||||
}
|
||||
auto res = normalize_symints(*this, sci);
|
||||
return SymInt::toSymInt(res[0]->mul(res[1]));
|
||||
return SymInt(res[0]->mul(res[1]));
|
||||
}
|
||||
|
||||
SymInt SymInt::operator/(SymInt sci) const {
|
||||
@ -81,7 +80,7 @@ SymInt SymInt::operator/(SymInt sci) const {
|
||||
return SymInt(data_ / sci.data_);
|
||||
}
|
||||
auto res = normalize_symints(*this, sci);
|
||||
return SymInt::toSymInt(res[0]->floordiv(res[1]));
|
||||
return SymInt(res[0]->floordiv(res[1]));
|
||||
}
|
||||
|
||||
SymInt SymInt::operator%(SymInt sci) const {
|
||||
@ -89,7 +88,7 @@ SymInt SymInt::operator%(SymInt sci) const {
|
||||
return SymInt(data_ % sci.data_);
|
||||
}
|
||||
auto res = normalize_symints(*this, sci);
|
||||
return SymInt::toSymInt(res[0]->mod(res[1]));
|
||||
return SymInt(res[0]->mod(res[1]));
|
||||
}
|
||||
|
||||
bool SymInt::operator==(SymInt sci) const {
|
||||
@ -141,14 +140,14 @@ SymInt SymInt::min(SymInt sci) const {
|
||||
return std::min(data_, sci.data_);
|
||||
}
|
||||
auto res = normalize_symints(*this, sci);
|
||||
return SymInt::toSymInt(res[0]->min(res[1]));
|
||||
return SymInt(res[0]->min(res[1]));
|
||||
}
|
||||
SymInt SymInt::max(SymInt sci) const {
|
||||
if (!is_symbolic() && !sci.is_symbolic()) {
|
||||
return std::max(data_, sci.data_);
|
||||
}
|
||||
auto res = normalize_symints(*this, sci);
|
||||
return SymInt::toSymInt(res[0]->max(res[1]));
|
||||
return SymInt(res[0]->max(res[1]));
|
||||
}
|
||||
|
||||
void SymInt::operator*=(SymInt sci) {
|
||||
@ -193,7 +192,7 @@ SymInt SymInt::operator*(int64_t sci) const {
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, SymInt s) {
|
||||
if (s.is_symbolic()) {
|
||||
os << s.toSymIntNodeImpl()->str();
|
||||
os << s.toSymNodeImpl()->str();
|
||||
} else {
|
||||
os << s.as_int_unchecked();
|
||||
}
|
||||
@ -202,7 +201,7 @@ std::ostream& operator<<(std::ostream& os, SymInt s) {
|
||||
|
||||
SymInt operator-(SymInt s) {
|
||||
if (s.is_symbolic()) {
|
||||
return SymInt::toSymInt(s.toSymIntNodeImpl()->neg());
|
||||
return SymInt(s.toSymNodeImpl()->neg());
|
||||
} else {
|
||||
return SymInt(-s.as_int_unchecked());
|
||||
}
|
||||
|
@ -1,6 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/SymIntNodeImpl.h>
|
||||
#include <c10/core/SymNodeImpl.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/intrusive_ptr.h>
|
||||
@ -12,24 +12,19 @@ namespace c10 {
|
||||
|
||||
class SymFloat;
|
||||
|
||||
// `SymInt` is a C++ wrapper class around int64_t data_ which and is used to
|
||||
// represent concrete dimension values.
|
||||
// SymInt represents either a regular int64_t, or a symbolic integer
|
||||
// (represented in a type erased way as SymNode). The intention is for SymInt
|
||||
// to represent symbolic sizes that arise when doing shape computation in
|
||||
// operator kernels. This allows for tracing through programs without baking in
|
||||
// concrete sizes into kernel calls.
|
||||
//
|
||||
// `SymInt` is also a data type in Pytorch that can be used in function schemas
|
||||
// to enable tracing.
|
||||
// SymInt has an API equivalent to int64_t. In particular, it is a value type.
|
||||
// Internally, SymInt is represented in a clever packed way, so that it only
|
||||
// occupies one word of space; but morally, it is a union between an int64_t
|
||||
// and an intrusive pointer to SymNodeImpl.
|
||||
//
|
||||
// `SymInt` is introduced to enable tracing arithmetic
|
||||
// operations on symbolic integers (e.g. sizes). Tracing symbolic sizes will
|
||||
// allow LTC and AOTAutograd representing dynamic shapes in expression graphs
|
||||
// faithfully without baking in concrete dimension values.
|
||||
//
|
||||
// To trace the operations, SymInt will overload arithmetic operators (e.g. +,
|
||||
// -, *) and will provide overloads taking SymInt for commonly used math
|
||||
// functions.
|
||||
//
|
||||
// SymInt will be extenteded to represent a union structure Union[int64_t,
|
||||
// SymIntNodeImpl*] which will be implemented as a single packed int64_t field
|
||||
// named data_.
|
||||
// Invariant: the referenced SymNodeImpl is guaranteed to be a SymNode where
|
||||
// is_int() returns true
|
||||
|
||||
class C10_API SymInt {
|
||||
public:
|
||||
@ -44,6 +39,7 @@ class C10_API SymInt {
|
||||
TORCH_CHECK(!is_symbolic());
|
||||
};
|
||||
SymInt() : data_(0) {}
|
||||
SymInt(SymNode n);
|
||||
|
||||
// unchecked c-tor accepting raw `data_`
|
||||
// One appropriate use for this is when you are constructing a symint
|
||||
@ -55,7 +51,7 @@ class C10_API SymInt {
|
||||
// temporary and then use the move constructor/assignment
|
||||
SymInt(const SymInt& s) : data_(0) {
|
||||
if (s.is_symbolic()) {
|
||||
*this = SymInt::toSymInt(s.toSymIntNodeImpl());
|
||||
*this = SymInt(s.toSymNodeImpl());
|
||||
} else {
|
||||
data_ = s.data_;
|
||||
}
|
||||
@ -67,7 +63,7 @@ class C10_API SymInt {
|
||||
SymInt& operator=(const SymInt& s) {
|
||||
if (this != &s) {
|
||||
if (s.is_symbolic()) {
|
||||
*this = SymInt::toSymInt(s.toSymIntNodeImpl());
|
||||
*this = SymInt(s.toSymNodeImpl());
|
||||
} else {
|
||||
data_ = s.data_;
|
||||
}
|
||||
@ -76,7 +72,7 @@ class C10_API SymInt {
|
||||
}
|
||||
SymInt& operator=(SymInt&& s) {
|
||||
if (this != &s) {
|
||||
release_(); // release the current SymIntNode if any
|
||||
release_(); // release the current SymNode if any
|
||||
data_ = s.data_;
|
||||
if (s.is_symbolic())
|
||||
s.data_ = 0;
|
||||
@ -86,31 +82,31 @@ class C10_API SymInt {
|
||||
|
||||
SymInt clone() const {
|
||||
if (is_symbolic()) {
|
||||
return toSymIntNodeImplUnowned()->clone()->toSymInt();
|
||||
return SymInt(toSymNodeImplUnowned()->clone());
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
SymIntNodeImpl* toSymIntNodeImplUnowned() const {
|
||||
SymNodeImpl* toSymNodeImplUnowned() const {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(is_symbolic());
|
||||
uint64_t unextended_bits = static_cast<uint64_t>(data_) & ~MASK;
|
||||
uint64_t sign_bit_mask = 1ULL << (62 - 1);
|
||||
// https://stackoverflow.com/questions/42534749/signed-extension-from-24-bit-to-32-bit-in-c
|
||||
uint64_t extended_bits = (unextended_bits ^ sign_bit_mask) - sign_bit_mask;
|
||||
return static_cast<SymIntNodeImpl*>(
|
||||
return static_cast<SymNodeImpl*>(
|
||||
reinterpret_cast<void*>(static_cast<uintptr_t>(extended_bits)));
|
||||
}
|
||||
|
||||
void release_() {
|
||||
if (is_symbolic()) {
|
||||
SymIntNode::reclaim(toSymIntNodeImplUnowned()); // steal
|
||||
SymNode::reclaim(toSymNodeImplUnowned()); // steal
|
||||
}
|
||||
}
|
||||
|
||||
SymIntNodeImpl* release() && {
|
||||
SymNodeImpl* release() && {
|
||||
#ifndef C10_MOBILE
|
||||
TORCH_INTERNAL_ASSERT(is_symbolic());
|
||||
auto* r = toSymIntNodeImplUnowned();
|
||||
auto* r = toSymNodeImplUnowned();
|
||||
data_ = 0; // transfer ownership
|
||||
return r;
|
||||
#else
|
||||
@ -118,8 +114,7 @@ class C10_API SymInt {
|
||||
#endif
|
||||
}
|
||||
|
||||
SymIntNode toSymIntNodeImpl() const;
|
||||
static c10::SymInt toSymInt(SymIntNode sin);
|
||||
SymNode toSymNodeImpl() const;
|
||||
|
||||
~SymInt() {
|
||||
release_();
|
||||
|
@ -1,11 +0,0 @@
|
||||
#include <c10/core/SymInt.h>
|
||||
#include <c10/core/SymIntNodeImpl.h>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
c10::SymInt SymIntNodeImpl::toSymInt() {
|
||||
auto sit_sp = SymIntNode::reclaim_copy(this);
|
||||
return SymInt::toSymInt(sit_sp);
|
||||
}
|
||||
|
||||
} // namespace c10
|
3
c10/core/SymNodeImpl.cpp
Normal file
3
c10/core/SymNodeImpl.cpp
Normal file
@ -0,0 +1,3 @@
|
||||
#include <c10/core/SymNodeImpl.h>
|
||||
|
||||
namespace c10 {} // namespace c10
|
@ -1,6 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/SymFloatNodeImpl.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/intrusive_ptr.h>
|
||||
@ -10,13 +9,12 @@
|
||||
|
||||
namespace c10 {
|
||||
|
||||
class SymInt;
|
||||
class SymIntNodeImpl;
|
||||
class SymNodeImpl;
|
||||
using SymNode = c10::intrusive_ptr<SymNodeImpl>;
|
||||
|
||||
class C10_API SymIntNodeImpl : public c10::intrusive_ptr_target {
|
||||
class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
|
||||
public:
|
||||
c10::SymInt toSymInt();
|
||||
virtual ~SymIntNodeImpl(){};
|
||||
virtual ~SymNodeImpl(){};
|
||||
|
||||
template <typename T>
|
||||
c10::intrusive_ptr<T> dyn_cast() const {
|
||||
@ -24,66 +22,87 @@ class C10_API SymIntNodeImpl : public c10::intrusive_ptr_target {
|
||||
}
|
||||
|
||||
// these could be pure virtual when we implement LTC versions
|
||||
virtual SymIntNode add(const SymIntNode& other) {
|
||||
virtual bool is_int() {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual SymIntNode sub(const SymIntNode& other) {
|
||||
virtual bool is_float() {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual SymIntNode mul(const SymIntNode& other) {
|
||||
virtual SymNode add(const SymNode& other) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual SymFloatNode truediv(const SymIntNode& other) {
|
||||
virtual SymNode sub(const SymNode& other) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual SymIntNode floordiv(const SymIntNode& other) {
|
||||
virtual SymNode mul(const SymNode& other) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual SymIntNode mod(const SymIntNode& other) {
|
||||
virtual SymNode truediv(const SymNode& other) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual SymIntNode eq(const SymIntNode& other) {
|
||||
virtual SymNode pow(const SymNode& other) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual SymIntNode ne(const SymIntNode& other) {
|
||||
virtual SymNode floordiv(const SymNode& other) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual SymIntNode gt(const SymIntNode& other) {
|
||||
virtual SymNode mod(const SymNode& other) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual SymIntNode lt(const SymIntNode& other) {
|
||||
virtual SymNode eq(const SymNode& other) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual SymIntNode le(const SymIntNode& other) {
|
||||
virtual SymNode ne(const SymNode& other) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual SymIntNode ge(const SymIntNode& other) {
|
||||
virtual SymNode gt(const SymNode& other) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual SymIntNode ceil() {
|
||||
virtual SymNode lt(const SymNode& other) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual SymIntNode neg() {
|
||||
virtual SymNode le(const SymNode& other) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual SymIntNode min(const SymIntNode& other) {
|
||||
virtual SymNode ge(const SymNode& other) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual SymIntNode max(const SymIntNode& other) {
|
||||
virtual SymNode ceil() {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual SymIntNode clone() {
|
||||
virtual SymNode floor() {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual SymFloatNode sym_float() {
|
||||
virtual SymNode neg() {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual SymNode min(const SymNode& other) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual SymNode max(const SymNode& other) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual SymNode clone() {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual SymNode sym_int() {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
}
|
||||
virtual SymIntNode wrap(int64_t num) {
|
||||
virtual SymNode sym_float() {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
}
|
||||
virtual SymNode wrap_int(int64_t num) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual SymNode wrap_float(double num) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual int64_t guard_int(const char* file, int64_t line) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual double guard_float(const char* file, int64_t line) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual int64_t int_() {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
@ -1,7 +1,7 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <c10/core/SymInt.h>
|
||||
#include <c10/core/SymIntNodeImpl.h>
|
||||
#include <c10/core/SymNodeImpl.h>
|
||||
|
||||
using namespace c10;
|
||||
#ifndef C10_MOBILE
|
||||
@ -20,12 +20,6 @@ TEST(SymIntTest, ConcreteInts) {
|
||||
check(-4611686018427387904LL);
|
||||
}
|
||||
|
||||
TEST(SymIntTest, AddNode) {
|
||||
auto n = c10::make_intrusive<SymIntNodeImpl>();
|
||||
auto i = n->toSymInt();
|
||||
EXPECT_TRUE(i.is_symbolic());
|
||||
}
|
||||
|
||||
TEST(SymIntTest, CheckRange) {
|
||||
EXPECT_FALSE(SymInt::check_range(INT64_MIN));
|
||||
}
|
||||
|
@ -335,8 +335,8 @@ coverage_ignore_classes = [
|
||||
"Quantize",
|
||||
# torch.utils.backcompat
|
||||
"Warning",
|
||||
"SymIntNode",
|
||||
"SymFloatNode",
|
||||
"SymInt",
|
||||
"SymFloat",
|
||||
]
|
||||
|
||||
# The suffix(es) of source filenames.
|
||||
|
@ -605,7 +605,7 @@ class PytreeThunk:
|
||||
return x
|
||||
return pytree.tree_unflatten(x, self.spec)
|
||||
|
||||
KNOWN_TYPES = [torch.Tensor, int, str, float, bool, torch.SymIntNode, torch.SymFloatNode]
|
||||
KNOWN_TYPES = [torch.Tensor, int, str, float, bool, torch.SymInt, torch.SymFloat]
|
||||
|
||||
|
||||
def aot_function(
|
||||
|
@ -209,7 +209,7 @@ def _tensor_nbytes(numel, dtype):
|
||||
|
||||
def _size_of(node: fx.Node) -> int:
|
||||
def to_size_hint(s):
|
||||
if isinstance(s, torch.SymIntNode):
|
||||
if isinstance(s, torch.SymInt):
|
||||
py_s = s.get_pyobj()
|
||||
return py_s.shape_env.size_hint(py_s.expr)
|
||||
assert isinstance(s, int)
|
||||
|
@ -18,6 +18,8 @@ cond = PyOperator('cond')
|
||||
|
||||
def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
|
||||
def _unwrap_proxy(e):
|
||||
if not isinstance(e, (torch.Tensor, torch.SymInt, torch.SymFloat)):
|
||||
return e
|
||||
return get_proxy_slot(e, proxy_mode.tracer, e, lambda e: e.proxy)
|
||||
|
||||
assert isinstance(operands, list), "Cond operands must be a list of tensors"
|
||||
|
@ -1447,35 +1447,29 @@ TEST(TestSymInt, AddSymbolicInt) {
|
||||
}
|
||||
|
||||
#ifndef C10_MOBILE
|
||||
TEST(TestSymInt, TestIntrusive) {
|
||||
auto a = c10::make_intrusive<c10::SymIntNodeImpl>();
|
||||
auto b = c10::make_intrusive<c10::SymIntNodeImpl>();
|
||||
ASSERT_EQ(a.use_count(), 1);
|
||||
ASSERT_EQ(b.use_count(), 1);
|
||||
auto as = a->toSymInt();
|
||||
auto bs = b->toSymInt();
|
||||
ASSERT_EQ(a.use_count(), 2);
|
||||
ASSERT_EQ(b.use_count(), 2);
|
||||
as = bs;
|
||||
ASSERT_EQ(a.use_count(), 1);
|
||||
ASSERT_EQ(b.use_count(), 3);
|
||||
}
|
||||
|
||||
class TestSymIntNodeImpl : public c10::SymIntNodeImpl {
|
||||
class TestSymNodeImpl : public c10::SymNodeImpl {
|
||||
public:
|
||||
TestSymIntNodeImpl(int64_t i) : i_(i) {}
|
||||
explicit TestSymNodeImpl(int64_t i) : i_(i) {}
|
||||
|
||||
bool is_int() override {
|
||||
return true;
|
||||
};
|
||||
|
||||
bool is_float() override {
|
||||
return false;
|
||||
};
|
||||
|
||||
bool bool_() override {
|
||||
return static_cast<bool>(i_);
|
||||
};
|
||||
|
||||
#define OPDEF3(NAME, OP, RET) \
|
||||
RET NAME(const c10::SymIntNode& other) override { \
|
||||
return make_intrusive<TestSymIntNodeImpl>( \
|
||||
this->i_ OP dynamic_cast<TestSymIntNodeImpl*>(other.get())->i_); \
|
||||
#define OPDEF3(NAME, OP, RET) \
|
||||
RET NAME(const c10::SymNode& other) override { \
|
||||
return make_intrusive<TestSymNodeImpl>( \
|
||||
this->i_ OP dynamic_cast<TestSymNodeImpl*>(other.get())->i_); \
|
||||
}
|
||||
|
||||
#define OPDEF2(NAME, OP) OPDEF3(NAME, OP, c10::SymIntNode)
|
||||
#define OPDEF2(NAME, OP) OPDEF3(NAME, OP, c10::SymNode)
|
||||
OPDEF2(add, +)
|
||||
OPDEF2(sub, -)
|
||||
OPDEF2(mul, *)
|
||||
@ -1494,17 +1488,19 @@ class TestSymIntNodeImpl : public c10::SymIntNodeImpl {
|
||||
int64_t i_;
|
||||
};
|
||||
|
||||
TEST(TestSymInt, TestSymIntToSymIntNodeDispatch) {
|
||||
TEST(TestSymInt, TestSymIntToSymNodeDispatch) {
|
||||
auto get = [](c10::SymInt si) {
|
||||
auto node = si.toSymIntNodeImpl();
|
||||
return dynamic_cast<TestSymIntNodeImpl*>(node.get())->i_;
|
||||
auto node = si.toSymNodeImpl();
|
||||
return dynamic_cast<TestSymNodeImpl*>(node.get())->i_;
|
||||
};
|
||||
|
||||
std::vector<int64_t> inputs{0, 1, -1, 4, -4, 777, -777};
|
||||
for (auto i : inputs) {
|
||||
for (auto j : inputs) {
|
||||
auto a = c10::make_intrusive<TestSymIntNodeImpl>(i)->toSymInt();
|
||||
auto b = c10::make_intrusive<TestSymIntNodeImpl>(j)->toSymInt();
|
||||
auto a = c10::SymInt(
|
||||
static_cast<SymNode>(c10::make_intrusive<TestSymNodeImpl>(i)));
|
||||
auto b = c10::SymInt(
|
||||
static_cast<SymNode>(c10::make_intrusive<TestSymNodeImpl>(j)));
|
||||
ASSERT_EQ(get(a + b), i + j);
|
||||
ASSERT_EQ(get(a - b), i - j);
|
||||
ASSERT_EQ(get(a * b), i * j);
|
||||
|
@ -12,8 +12,9 @@ import itertools
|
||||
import io
|
||||
from torch.utils._pytree import tree_map
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv, PySymInt, sym_float
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv, sym_float, guard_int, SymNode
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
from torch import SymInt
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
@ -116,9 +117,6 @@ def create_symbolic_tensor(name, arg, shape_env, storage_offset=0):
|
||||
sym_shapes, sym_strides = shape_env.create_symbolic_sizes_strides(arg)
|
||||
return FakeSymbolicTensor(sym_shapes, sym_strides, arg.dtype, arg.layout, arg.requires_grad, arg.device, storage_offset)
|
||||
|
||||
|
||||
CPP_SYMINT_CLASS = type(torch.SymIntNode.new_symint(1))
|
||||
|
||||
def create_symint(shape_env, i):
|
||||
return shape_env.create_symintnode(shape_env.create_symbol(i))
|
||||
|
||||
@ -156,8 +154,8 @@ class TestPySymInt(TestCase):
|
||||
shape_env = ShapeEnv()
|
||||
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
|
||||
|
||||
self.assertTrue(not isinstance(x.shape[0], PySymInt))
|
||||
self.assertTrue(isinstance(x.shape[0], CPP_SYMINT_CLASS))
|
||||
self.assertTrue(not isinstance(x.shape[0], SymNode))
|
||||
self.assertTrue(isinstance(x.shape[0], SymInt))
|
||||
|
||||
self.assertTrue(x.shape[0] == 5)
|
||||
self.assertTrue(x.shape[1] == 4)
|
||||
@ -165,17 +163,17 @@ class TestPySymInt(TestCase):
|
||||
|
||||
self.assertTrue(x.size()[0], 5)
|
||||
self.assertTrue(x.size()[1], 4)
|
||||
self.assertTrue(isinstance(x.size()[1], CPP_SYMINT_CLASS))
|
||||
self.assertTrue(isinstance(x.size()[1], SymInt))
|
||||
self.assertTrue(x.size()[2] == 3)
|
||||
|
||||
self.assertTrue(x.size(0) == 5)
|
||||
self.assertTrue(x.size(1) == 4)
|
||||
self.assertTrue(x.size(2) == 3)
|
||||
self.assertTrue(isinstance(x.size(2), CPP_SYMINT_CLASS))
|
||||
self.assertTrue(isinstance(x.size(2), SymInt))
|
||||
|
||||
offset = create_symint(shape_env, 2)
|
||||
y = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env, offset)
|
||||
self.assertTrue(isinstance(y.storage_offset(), CPP_SYMINT_CLASS))
|
||||
self.assertTrue(isinstance(y.storage_offset(), SymInt))
|
||||
self.assertTrue(y.storage_offset() == 2)
|
||||
|
||||
offset = 2
|
||||
@ -267,7 +265,7 @@ class TestPySymInt(TestCase):
|
||||
def test_stride(self):
|
||||
shape_env = ShapeEnv()
|
||||
x = create_symbolic_tensor("x", torch.randn(5, 5), shape_env)
|
||||
self.assertIsInstance(x.stride()[0], CPP_SYMINT_CLASS)
|
||||
self.assertIsInstance(x.stride()[0], SymInt)
|
||||
|
||||
@skipIfNoSympy
|
||||
def test_size_expressions(self):
|
||||
@ -290,7 +288,7 @@ class TestPySymInt(TestCase):
|
||||
shape_env = ShapeEnv()
|
||||
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
|
||||
r = sym_float(x.shape[0])
|
||||
self.assertTrue(isinstance(r, torch.SymFloatNode))
|
||||
self.assertIsInstance(r, torch.SymFloat, msg=type(r))
|
||||
|
||||
@skipIfNoSympy
|
||||
def test_aten_ops(self):
|
||||
@ -320,13 +318,13 @@ class TestPySymInt(TestCase):
|
||||
shape_env = ShapeEnv()
|
||||
a0 = create_symint(shape_env, 2)
|
||||
r = torch.empty(a0, device='meta')
|
||||
self.assertIsInstance(r.shape[0], CPP_SYMINT_CLASS)
|
||||
self.assertIsInstance(r.shape[0], SymInt)
|
||||
|
||||
@skipIfNoSympy
|
||||
def test_guard_int(self):
|
||||
shape_env = ShapeEnv()
|
||||
a0 = create_symint(shape_env, 2)
|
||||
self.assertEqual(a0.guard_int(), 2)
|
||||
self.assertEqual(guard_int(a0), 2)
|
||||
self.assertEqual(str(shape_env.guards[0][0]), "Eq(s0, 2)")
|
||||
|
||||
@skipIfNoSympy
|
||||
@ -347,7 +345,9 @@ class TestPySymInt(TestCase):
|
||||
assert func == torch.ops.aten.add.Tensor
|
||||
|
||||
nonlocal sym_int_encountered
|
||||
sym_int_encountered = kwargs["alpha"] is a0
|
||||
# WARNING: do not do identity tests on the outer
|
||||
# SymInt/SymFloat, they are NOT STABLE
|
||||
sym_int_encountered = kwargs["alpha"].node is a0.node
|
||||
kwargs["alpha"] = 0
|
||||
return func(*args)
|
||||
|
||||
|
@ -1,391 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
from torch._C import _disabled_torch_function_impl
|
||||
import torch.fx
|
||||
import torch.nn.functional as F
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase, skipIfTorchDynamo
|
||||
import unittest
|
||||
import torch
|
||||
import operator
|
||||
import itertools
|
||||
import io
|
||||
from torch.utils._pytree import tree_map
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv, PySymInt, sym_float
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
try:
|
||||
import sympy
|
||||
HAS_SYMPY = True
|
||||
except ImportError:
|
||||
HAS_SYMPY = False
|
||||
skipIfNoSympy = unittest.skipIf(not HAS_SYMPY, "no sympy")
|
||||
|
||||
|
||||
meta_funcs = {}
|
||||
|
||||
|
||||
def register_meta(op):
|
||||
def decorator(f):
|
||||
def add_func(op):
|
||||
meta_funcs[op] = f
|
||||
tree_map(add_func, op)
|
||||
return f
|
||||
return decorator
|
||||
|
||||
|
||||
@register_meta([aten.add.Tensor, aten.sub.Tensor])
|
||||
def binary_meta(a, b):
|
||||
return a.new_empty(a.shape)
|
||||
|
||||
|
||||
@register_meta(aten.cat.default)
|
||||
def cat_meta(tensors, dim=0):
|
||||
concat_length = 0
|
||||
shape = tensors[0].shape
|
||||
for tensor in tensors:
|
||||
for idx, (common_length, length) in enumerate(zip(shape, tensor.shape)):
|
||||
if idx == dim:
|
||||
concat_length = concat_length + length
|
||||
else:
|
||||
assert length == common_length
|
||||
new_shape = list(shape)
|
||||
new_shape[dim] = concat_length
|
||||
return tensors[0].new_empty(new_shape)
|
||||
|
||||
|
||||
@register_meta([aten.narrow_copy.default])
|
||||
def narrow_copy_symint_meta(a, dim, start, length, **kwargs):
|
||||
shape = []
|
||||
for i, x in enumerate(a.shape):
|
||||
if i == dim:
|
||||
shape.append(length)
|
||||
else:
|
||||
shape.append(x)
|
||||
return a.new_empty(tuple(shape))
|
||||
|
||||
|
||||
@register_meta([aten.expand.default])
|
||||
def expand_symint_meta(a, size, implicit=False):
|
||||
return a.new_empty(size)
|
||||
|
||||
|
||||
def create_contiguous(shape):
|
||||
strides = [1]
|
||||
for dim in reversed(shape[:-1]):
|
||||
strides.append(dim * strides[-1])
|
||||
return list(reversed(strides))
|
||||
|
||||
|
||||
class FakeSymbolicTensor(torch.Tensor):
|
||||
@staticmethod
|
||||
def __new__(cls, sym_shape, sym_strides, dtype, layout, requires_grad, device, storage_offset=0):
|
||||
# TODO: this is wrong in general
|
||||
sym_stride = create_contiguous(sym_shape)
|
||||
r = torch.Tensor._make_wrapper_subclass(
|
||||
cls, sym_shape,
|
||||
sym_stride, storage_offset,
|
||||
dtype=dtype, layout=layout, requires_grad=requires_grad,
|
||||
device=device,
|
||||
)
|
||||
return r
|
||||
|
||||
__torch_function__ = _disabled_torch_function_impl
|
||||
|
||||
def new_empty(self, shape):
|
||||
return FakeSymbolicTensor(shape, None, self.dtype, self.layout, self.requires_grad, self.device)
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
|
||||
if func_overload in meta_funcs:
|
||||
return meta_funcs[func_overload](*args, **kwargs)
|
||||
|
||||
if func_overload == torch.ops.aten.new_empty.default:
|
||||
self = args[0]
|
||||
shape = args[1]
|
||||
return FakeSymbolicTensor(shape, self.stride(), self.dtype, self.layout, self.requires_grad, self.device)
|
||||
|
||||
raise RuntimeError(f"operator {func_overload} not supported")
|
||||
|
||||
|
||||
def create_symbolic_tensor(name, arg, shape_env, storage_offset=0):
|
||||
sym_shapes, sym_strides = shape_env.create_symbolic_sizes_strides(arg)
|
||||
return FakeSymbolicTensor(sym_shapes, sym_strides, arg.dtype, arg.layout, arg.requires_grad, arg.device, storage_offset)
|
||||
|
||||
|
||||
CPP_SYMINT_CLASS = type(torch.SymIntNode.new_symint(1))
|
||||
|
||||
def create_symint(shape_env, i):
|
||||
return shape_env.create_symintnode(shape_env.create_symbol(i))
|
||||
|
||||
@skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)")
|
||||
class TestPySymInt(TestCase):
|
||||
|
||||
@skipIfNoSympy
|
||||
def test_arith_ops(self):
|
||||
shape_env = ShapeEnv()
|
||||
symints = []
|
||||
for i in range(2, 5):
|
||||
symints.append((i, create_symint(shape_env, i)))
|
||||
|
||||
ops = [operator.add, operator.sub, operator.floordiv, operator.mul, operator.mod]
|
||||
|
||||
for op in ops:
|
||||
for args in itertools.permutations(symints, 2):
|
||||
if not isinstance(args[0][1], int) and ((op != operator.mod or op != operator.floordiv) and args[1][0] != 0):
|
||||
self.assertTrue(op(args[0][1], args[1][1]) == op(args[0][0], args[1][0]))
|
||||
|
||||
|
||||
@skipIfNoSympy
|
||||
def test_reverse_arith_ops(self):
|
||||
shape_env = ShapeEnv()
|
||||
|
||||
a = create_symint(shape_env, 2)
|
||||
self.assertTrue(5 // a == 5 // 2)
|
||||
|
||||
a = create_symint(shape_env, 2)
|
||||
self.assertTrue(5 * a == 5 * 2)
|
||||
|
||||
|
||||
@skipIfNoSympy
|
||||
def test_roundtrip(self):
|
||||
shape_env = ShapeEnv()
|
||||
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
|
||||
|
||||
self.assertTrue(not isinstance(x.shape[0], PySymInt))
|
||||
self.assertTrue(isinstance(x.shape[0], CPP_SYMINT_CLASS))
|
||||
|
||||
self.assertTrue(x.shape[0] == 5)
|
||||
self.assertTrue(x.shape[1] == 4)
|
||||
self.assertTrue(x.shape[2], 3)
|
||||
|
||||
self.assertTrue(x.size()[0], 5)
|
||||
self.assertTrue(x.size()[1], 4)
|
||||
self.assertTrue(isinstance(x.size()[1], CPP_SYMINT_CLASS))
|
||||
self.assertTrue(x.size()[2] == 3)
|
||||
|
||||
self.assertTrue(x.size(0) == 5)
|
||||
self.assertTrue(x.size(1) == 4)
|
||||
self.assertTrue(x.size(2) == 3)
|
||||
self.assertTrue(isinstance(x.size(2), CPP_SYMINT_CLASS))
|
||||
|
||||
offset = create_symint(shape_env, 2)
|
||||
y = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env, offset)
|
||||
self.assertTrue(isinstance(y.storage_offset(), CPP_SYMINT_CLASS))
|
||||
self.assertTrue(y.storage_offset() == 2)
|
||||
|
||||
offset = 2
|
||||
z = create_symbolic_tensor("z", torch.randn(5, 4, 3), shape_env, offset)
|
||||
self.assertTrue(isinstance(z.storage_offset(), int))
|
||||
self.assertTrue(z.storage_offset() == 2)
|
||||
|
||||
@skipIfNoSympy
|
||||
def test_binary(self):
|
||||
shape_env = ShapeEnv()
|
||||
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
|
||||
y = create_symbolic_tensor("y", torch.randn(5, 4, 3), shape_env)
|
||||
|
||||
z = x + y
|
||||
self.assertTrue(z.shape[0] == 5)
|
||||
self.assertTrue(z.shape[1] == 4)
|
||||
self.assertTrue(z.shape[2] == 3)
|
||||
|
||||
# broadcasting
|
||||
y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env)
|
||||
z = x + y
|
||||
self.assertTrue(z.shape[0] == 5)
|
||||
self.assertTrue(z.shape[1] == 4)
|
||||
self.assertTrue(z.shape[2] == 3)
|
||||
|
||||
@skipIfNoSympy
|
||||
def test_symint_args(self):
|
||||
shape_env = ShapeEnv()
|
||||
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
|
||||
y = create_symbolic_tensor("y", torch.randn(5, 4, 1), shape_env)
|
||||
LAST_DIM = 2
|
||||
z = x.narrow_copy(LAST_DIM, 0, y.shape[LAST_DIM])
|
||||
self.assertTrue(z.shape[2] == y.shape[2])
|
||||
|
||||
# arithmetic expr with two symints
|
||||
z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - y.shape[LAST_DIM])
|
||||
self.assertTrue(z.shape[2] == 2)
|
||||
|
||||
# arithmetic expr with a symint and python int
|
||||
z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - 1)
|
||||
self.assertTrue(z.shape[2] == 2)
|
||||
|
||||
@skipIfNoSympy
|
||||
def test_symint_vargs(self):
|
||||
shape_env = ShapeEnv()
|
||||
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
|
||||
y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env)
|
||||
|
||||
# varargs
|
||||
z = y.expand(x.shape[0], y.shape[1], x.shape[2])
|
||||
self.assertTrue(z.shape[0] == 5)
|
||||
self.assertTrue(z.shape[1] == 4)
|
||||
self.assertTrue(z.shape[2] == 3)
|
||||
|
||||
# shape list
|
||||
z = y.expand((x.shape[0], y.shape[1], x.shape[2]))
|
||||
self.assertTrue(z.shape[0] == 5)
|
||||
self.assertTrue(z.shape[1] == 4)
|
||||
self.assertTrue(z.shape[2] == 3)
|
||||
|
||||
# mixed python symints and ints
|
||||
z = y.expand(x.shape[0], y.shape[1], 3)
|
||||
self.assertTrue(z.shape[0] == 5)
|
||||
self.assertTrue(z.shape[1] == 4)
|
||||
self.assertTrue(z.shape[2] == 3)
|
||||
|
||||
# mixed python symints and ints in a list
|
||||
z = y.expand((x.shape[0], y.shape[1], 3))
|
||||
self.assertTrue(z.shape[0] == 5)
|
||||
self.assertTrue(z.shape[1] == 4)
|
||||
self.assertTrue(z.shape[2] == 3)
|
||||
|
||||
# mixed python symints and ints
|
||||
z = y.expand(5, y.shape[1], x.shape[2])
|
||||
self.assertTrue(z.shape[0] == 5)
|
||||
self.assertTrue(z.shape[1] == 4)
|
||||
self.assertTrue(z.shape[2] == 3)
|
||||
|
||||
# mixed python ints and symints in a list
|
||||
z = y.expand((5, y.shape[1], x.shape[2]))
|
||||
self.assertTrue(z.shape[0] == 5)
|
||||
self.assertTrue(z.shape[1] == 4)
|
||||
self.assertTrue(z.shape[2] == 3)
|
||||
|
||||
z = y.expand((y.shape[1],))
|
||||
z = y.expand(y.shape[1])
|
||||
|
||||
@skipIfNoSympy
|
||||
def test_stride(self):
|
||||
shape_env = ShapeEnv()
|
||||
x = create_symbolic_tensor("x", torch.randn(5, 5), shape_env)
|
||||
self.assertIsInstance(x.stride()[0], CPP_SYMINT_CLASS)
|
||||
|
||||
@skipIfNoSympy
|
||||
def test_size_expressions(self):
|
||||
shape_env = ShapeEnv()
|
||||
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
|
||||
expand_x = x.expand(x.shape[0], x.shape[0])
|
||||
if expand_x.shape[0] > 3:
|
||||
result = expand_x + expand_x
|
||||
else:
|
||||
result = expand_x + expand_x
|
||||
|
||||
gt_op = shape_env.guards[0][0]
|
||||
self.assertTrue(isinstance(gt_op, sympy.core.relational.StrictGreaterThan))
|
||||
self.assertTrue(str(x.shape[0]), str(gt_op.args[0]))
|
||||
self.assertTrue(str(expand_x.shape[1]), str(x.shape[0]))
|
||||
self.assertTrue(str(expand_x.shape[1]), str(result.shape[0]))
|
||||
|
||||
@skipIfNoSympy
|
||||
def test_int_to_float(self):
|
||||
shape_env = ShapeEnv()
|
||||
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
|
||||
r = sym_float(x.shape[0])
|
||||
self.assertTrue(isinstance(r, torch.SymFloatNode))
|
||||
|
||||
@skipIfNoSympy
|
||||
def test_aten_ops(self):
|
||||
|
||||
shape_env = ShapeEnv()
|
||||
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
|
||||
torch.ops.aten.narrow_copy.default(x, 0, 0, x.shape[0])
|
||||
|
||||
shape_env = ShapeEnv()
|
||||
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
|
||||
torch.ops.aten.expand.default(x, [x.shape[0], x.shape[1], x.shape[2]])
|
||||
|
||||
def test_fx_trace_intlist(self):
|
||||
class CustomModule(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
bs, c, h, w = x.shape
|
||||
return F.pad(x, (0, w % 2, 0, h % 2, 0, 0))
|
||||
|
||||
m = CustomModule()
|
||||
x = torch.rand(1, 3, 4, 4)
|
||||
# should not TypeError: pad(): argument 'pad' (position 2) must be
|
||||
# tuple of ints, not tuple
|
||||
torch.fx.symbolic_trace(m)
|
||||
|
||||
@skipIfNoSympy
|
||||
def test_meta_symint(self):
|
||||
shape_env = ShapeEnv()
|
||||
a0 = create_symint(shape_env, 2)
|
||||
r = torch.empty(a0, device='meta')
|
||||
self.assertIsInstance(r.shape[0], CPP_SYMINT_CLASS)
|
||||
|
||||
@skipIfNoSympy
|
||||
def test_guard_int(self):
|
||||
shape_env = ShapeEnv()
|
||||
a0 = create_symint(shape_env, 2)
|
||||
self.assertEqual(a0.guard_int(), 2)
|
||||
self.assertEqual(str(shape_env.guards[0][0]), "s0")
|
||||
self.assertEqual(shape_env.guards[0][1], 2)
|
||||
|
||||
@skipIfNoSympy
|
||||
def test_int_conversion(self):
|
||||
shape_env = ShapeEnv()
|
||||
a0 = create_symint(shape_env, 2)
|
||||
self.assertRaisesRegex(RuntimeError, "Trying to extract", lambda: int(a0))
|
||||
|
||||
@skipIfNoSympy
|
||||
def test_symint_as_scalar(self):
|
||||
shape_env = ShapeEnv()
|
||||
a0 = create_symint(shape_env, 2)
|
||||
|
||||
sym_int_encountered = False
|
||||
|
||||
class TestSymInt(TorchDispatchMode):
|
||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
assert func == torch.ops.aten.add.Tensor
|
||||
|
||||
nonlocal sym_int_encountered
|
||||
sym_int_encountered = kwargs["alpha"] is a0
|
||||
kwargs["alpha"] = 0
|
||||
return func(*args)
|
||||
|
||||
x = torch.rand([4, 4])
|
||||
with TestSymInt():
|
||||
y = torch.add(x, x, alpha=a0)
|
||||
|
||||
self.assertTrue(sym_int_encountered)
|
||||
|
||||
@skipIfNoSympy
|
||||
@unittest.mock.patch('sys.stdout', new_callable=io.StringIO)
|
||||
def test_print_readable_with_symints(self, mock_stdout):
|
||||
def f(a, b):
|
||||
dim0 = a.shape[0] + b.shape[0]
|
||||
dim1 = a.shape[1] + b.shape[1]
|
||||
d = a.new_empty(dim0, dim1)
|
||||
d = torch.ops.aten.native_dropout(d, 0.5, train=True)
|
||||
return d
|
||||
|
||||
fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5, 3), torch.randn(4, 3))
|
||||
fx_g.print_readable()
|
||||
|
||||
self.assertExpectedInline(mock_stdout.getvalue().strip(), """\
|
||||
class f(torch.nn.Module):
|
||||
def forward(self, a_1: f32[t0.size(0),t0.size(1)], b_1: f32[t1.size(0),t0.size(1)]):
|
||||
# No stacktrace found for following nodes
|
||||
sym_size: Sym(t0.size(0)) = torch.ops.aten.sym_size(a_1, 0)
|
||||
sym_size_1: Sym(t1.size(0)) = torch.ops.aten.sym_size(b_1, 0)
|
||||
add: Sym(t0.size(0) + t1.size(0)) = sym_size + sym_size_1; sym_size = sym_size_1 = None
|
||||
sym_size_2: Sym(t0.size(1)) = torch.ops.aten.sym_size(a_1, 1)
|
||||
sym_size_3: Sym(t0.size(1)) = torch.ops.aten.sym_size(b_1, 1); b_1 = None
|
||||
add_1: Sym(2*t0.size(1)) = sym_size_2 + sym_size_3; sym_size_2 = sym_size_3 = None
|
||||
new_empty: f32[t0.size(0) + t1.size(0),2*t0.size(1)] = torch.ops.aten.new_empty.default(a_1, [add, add_1], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False); a_1 = add = add_1 = None
|
||||
native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True); new_empty = None
|
||||
getitem: f32[t0.size(0) + t1.size(0),2*t0.size(1)] = native_dropout[0]
|
||||
getitem_1: b8[t0.size(0) + t1.size(0),2*t0.size(1)] = native_dropout[1]; native_dropout = None
|
||||
return (getitem, getitem_1)""") # noqa: B950
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
@ -875,8 +875,7 @@ def forward(self, a_1):
|
||||
self.assertExpectedInline(r, """\
|
||||
def forward(self, a_1):
|
||||
sym_size = torch.ops.aten.sym_size(a_1, 0)
|
||||
sym_float = torch.fx.experimental.symbolic_shapes.sym_float(sym_size); sym_size = None
|
||||
pow_1 = sym_float ** 0.5; sym_float = None
|
||||
pow_1 = sym_size ** 0.5; sym_size = None
|
||||
div = torch.ops.aten.div.Tensor(a_1, pow_1); a_1 = pow_1 = None
|
||||
return div""")
|
||||
|
||||
@ -949,7 +948,7 @@ def forward(self, a_1):
|
||||
fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5), torch.randn(4))
|
||||
meta_c = _get_node(fx_g, lambda x: x.target == aten.new_empty.default)
|
||||
meta_d = _get_node(fx_g, lambda x: x.target == operator.add)
|
||||
self.assertTrue(meta_c.meta['val'].shape[0].get_pyobj().expr == meta_d.meta['val'].expr)
|
||||
self.assertTrue(meta_c.meta['val'].shape[0].get_pyobj().expr == meta_d.meta['val'].node.expr)
|
||||
|
||||
def test_metadata_fresh(self):
|
||||
def f(x):
|
||||
|
@ -207,8 +207,8 @@ class TestPublicBindings(TestCase):
|
||||
"StreamObjType",
|
||||
"StringType",
|
||||
"SUM",
|
||||
"SymFloatNode",
|
||||
"SymIntNode",
|
||||
"SymFloat",
|
||||
"SymInt",
|
||||
"TensorType",
|
||||
"ThroughputBenchmark",
|
||||
"TracingState",
|
||||
|
@ -291,7 +291,7 @@ PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
|
||||
for (auto i : c10::irange(prop.size())) {
|
||||
auto si = prop[i];
|
||||
if (si.is_symbolic()) {
|
||||
auto py_symint = py::cast(si.toSymIntNodeImpl()).release().ptr();
|
||||
auto py_symint = py::cast(si).release().ptr();
|
||||
PyTuple_SetItem(tup, (Py_ssize_t) i, py_symint);
|
||||
} else {
|
||||
PyTuple_SetItem(tup, (Py_ssize_t) i, PyLong_FromUnsignedLong(si.as_int_unchecked()));
|
||||
@ -313,7 +313,7 @@ return PyLong_FromUnsignedLong((int64_t) prop);
|
||||
"""
|
||||
|
||||
GETTER_BODY_SYMINT = """\
|
||||
return prop.is_symbolic() ? py::cast(prop.toSymIntNodeImpl()).release().ptr() : PyLong_FromUnsignedLong(prop.as_int_unchecked());
|
||||
return prop.is_symbolic() ? py::cast(prop).release().ptr() : PyLong_FromUnsignedLong(prop.as_int_unchecked());
|
||||
"""
|
||||
|
||||
GETTER_BODY_DOUBLE = """\
|
||||
|
@ -5,7 +5,7 @@
|
||||
#include <Python.h>
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#include <c10/core/SymIntNodeImpl.h>
|
||||
#include <c10/core/SymNodeImpl.h>
|
||||
#include "torch/csrc/autograd/generated/Functions.h"
|
||||
#include "torch/csrc/autograd/python_cpp_function.h"
|
||||
#include <torch/csrc/autograd/python_variable.h>
|
||||
|
@ -240,12 +240,7 @@ static PyObject * THPVariable_numel(PyObject* self, PyObject* args)
|
||||
if (jit::tracer::isTracing()) {
|
||||
return wrap(jit::tracer::getNumelOf(self_));
|
||||
} else {
|
||||
auto si = self_.sym_numel();
|
||||
if (si.is_symbolic()) {
|
||||
return py::cast(si.toSymIntNodeImpl()).release().ptr();
|
||||
} else {
|
||||
return THPUtils_packInt64(si.as_int_unchecked());
|
||||
}
|
||||
return py::cast(self_.sym_numel()).release().ptr();
|
||||
}
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
@ -722,7 +722,7 @@ def gen_pyi(
|
||||
binop += "_"
|
||||
out_suffix = ""
|
||||
unsorted_tensor_method_hints[binop].append(
|
||||
"def {}(self, other: Union[Tensor, Number, torch.SymIntNode, torch.SymFloatNode]{})"
|
||||
"def {}(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat]{})"
|
||||
" -> Tensor: ...".format(binop, out_suffix)
|
||||
)
|
||||
for binop in ["add", "sub"]:
|
||||
@ -732,7 +732,7 @@ def gen_pyi(
|
||||
binop += "_"
|
||||
out_suffix = ""
|
||||
unsorted_tensor_method_hints[binop].append(
|
||||
"def {}(self, other: Union[Tensor, Number, torch.SymIntNode, torch.SymFloatNode], "
|
||||
"def {}(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat], "
|
||||
"*, alpha: Optional[Number]=1{})"
|
||||
" -> Tensor: ...".format(binop, out_suffix)
|
||||
)
|
||||
|
@ -169,20 +169,6 @@ class Future(object):
|
||||
|
||||
def _jit_set_num_profiled_runs(num: _size) -> _size: ...
|
||||
|
||||
class SymIntNode(object):
|
||||
def get_pyobj(self) -> Any: ...
|
||||
|
||||
@staticmethod
|
||||
def new_symint(obj) -> SymIntNode: ...
|
||||
|
||||
class SymFloatNode(object):
|
||||
def get_pyobj(self) -> Any: ...
|
||||
|
||||
@staticmethod
|
||||
def new_symfloat(obj) -> SymFloatNode: ...
|
||||
|
||||
def __ceil__(self) -> SymIntNode: ...
|
||||
|
||||
# Defined in torch/csrc/jit/passes/xnnpack_rewrite.h
|
||||
class MobileOptimizerType:
|
||||
...
|
||||
|
@ -47,7 +47,7 @@ __all__ = [
|
||||
'is_deterministic_algorithms_warn_only_enabled',
|
||||
'set_deterministic_debug_mode', 'get_deterministic_debug_mode',
|
||||
'set_float32_matmul_precision', 'get_float32_matmul_precision',
|
||||
'set_warn_always', 'is_warn_always_enabled',
|
||||
'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat',
|
||||
]
|
||||
|
||||
################################################################################
|
||||
@ -196,6 +196,67 @@ else:
|
||||
if TYPE_CHECKING:
|
||||
import torch._C as _C
|
||||
|
||||
class SymInt:
|
||||
"""
|
||||
Like an int (including magic methods), but redirects all operations on the
|
||||
wrapped node. This is used in particular to symbolically record operations
|
||||
in the symbolic shape workflow.
|
||||
"""
|
||||
|
||||
def __init__(self, node):
|
||||
from torch.fx.experimental.symbolic_shapes import SymNode
|
||||
assert isinstance(node, SymNode)
|
||||
# This field MUST be named node; C++ binding code assumes that this
|
||||
# class has a field named node that stores SymNode
|
||||
self.node = node
|
||||
|
||||
# Magic methods installed later
|
||||
|
||||
def __bool__(self):
|
||||
return self.node.bool_()
|
||||
|
||||
def __int__(self):
|
||||
return self.node.int_()
|
||||
|
||||
def __sym_float__(self):
|
||||
return SymFloat(self.node.sym_float())
|
||||
|
||||
def __repr__(self):
|
||||
return self.node.str()
|
||||
|
||||
# For BC; direct access of node is OK too
|
||||
def get_pyobj(self):
|
||||
return self.node
|
||||
|
||||
class SymFloat:
|
||||
"""
|
||||
Like an float (including magic methods), but redirects all operations on the
|
||||
wrapped node. This is used in particular to symbolically record operations
|
||||
in the symbolic shape workflow.
|
||||
"""
|
||||
|
||||
def __init__(self, node):
|
||||
from torch.fx.experimental.symbolic_shapes import SymNode
|
||||
assert isinstance(node, SymNode)
|
||||
# This field MUST be named node; C++ binding code assumes that this
|
||||
# class has a field named node that stores SymNode
|
||||
self.node = node
|
||||
|
||||
# Magic methods installed later
|
||||
|
||||
def __bool__(self):
|
||||
return self.node.bool_()
|
||||
|
||||
def __sym_int__(self):
|
||||
return SymInt(self.node.sym_int())
|
||||
|
||||
def __repr__(self):
|
||||
return self.node.str()
|
||||
|
||||
# For BC; direct access of node is OK too
|
||||
def get_pyobj(self):
|
||||
return self.node
|
||||
|
||||
# Check to see if we can load C extensions, and if not provide some guidance
|
||||
# on what the problem might be.
|
||||
try:
|
||||
@ -941,7 +1002,6 @@ from ._linalg_utils import ( # type: ignore[misc]
|
||||
lstsq,
|
||||
)
|
||||
|
||||
|
||||
def _register_device_module(device_type, module):
|
||||
r"""Register an external runtime module of the specific :attr:`device_type`
|
||||
supported by torch.
|
||||
@ -971,3 +1031,6 @@ if 'TORCH_CUDA_SANITIZER' in os.environ:
|
||||
import torch.cuda._sanitizer as csan
|
||||
|
||||
csan.enable_cuda_sanitizer()
|
||||
|
||||
# Populate magic methods on SymInt and SymFloat
|
||||
import torch.fx.experimental.symbolic_shapes
|
||||
|
@ -337,7 +337,7 @@ class TensorVariable(VariableTracker):
|
||||
from . import UserDefinedObjectVariable
|
||||
|
||||
return UserDefinedObjectVariable(example_value)
|
||||
elif isinstance(example_value, torch.SymIntNode):
|
||||
elif isinstance(example_value, torch.SymInt):
|
||||
proxy.node.meta["example_value"] = example_value
|
||||
return cls(proxy, **options)
|
||||
else:
|
||||
|
@ -40,11 +40,9 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
else:
|
||||
size, stride = self._shape_env.create_symbolic_sizes_strides(ex)
|
||||
|
||||
size = [
|
||||
i.get_pyobj().expr if isinstance(i, torch.SymIntNode) else i for i in size
|
||||
]
|
||||
size = [i.get_pyobj().expr if isinstance(i, torch.SymInt) else i for i in size]
|
||||
stride = [
|
||||
i.get_pyobj().expr if isinstance(i, torch.SymIntNode) else i for i in stride
|
||||
i.get_pyobj().expr if isinstance(i, torch.SymInt) else i for i in stride
|
||||
]
|
||||
return size, stride
|
||||
|
||||
|
@ -392,8 +392,8 @@ def _elementwise_meta(
|
||||
# Number case
|
||||
# NOTE: this case is not currently exercised
|
||||
# TODO: fix number type promotion (bool, complex->float)
|
||||
assert not isinstance(number, torch.SymIntNode), "NYI"
|
||||
assert not isinstance(number, torch.SymFloatNode), "NYI"
|
||||
assert not isinstance(number, torch.SymInt), "NYI"
|
||||
assert not isinstance(number, torch.SymFloat), "NYI"
|
||||
return TensorMeta(number)
|
||||
|
||||
|
||||
@ -932,7 +932,7 @@ bitwise_xor = _make_elementwise_binary_prim(
|
||||
# div prim performs truncation division on integer inputs
|
||||
# and true division for floating and complex inputs
|
||||
def _div_aten(a, b):
|
||||
is_integral = isinstance(a, (bool, int, torch.SymIntNode)) or (
|
||||
is_integral = isinstance(a, (bool, int, torch.SymInt)) or (
|
||||
isinstance(a, torch.Tensor) and utils.is_integer_dtype(a.dtype)
|
||||
)
|
||||
|
||||
|
@ -42,18 +42,18 @@ ShapeType = Union[torch.Size, List[int], Tuple[int, ...]]
|
||||
StrideType = Union[List[int], Tuple[int, ...]]
|
||||
DimsType = Union[int, List[int], Tuple[int, ...]]
|
||||
DimsSequenceType = Union[List[int], Tuple[int, ...]]
|
||||
# TODO: Type[torch.SymIntNode], Type[torch.SymFloatNode]
|
||||
# TODO: Type[torch.SymInt], Type[torch.SymFloat]
|
||||
NumberTypeType = Union[Type[bool], Type[int], Type[float], Type[complex]]
|
||||
# TODO: This needs a lot more type annotations
|
||||
# NumberType = Union[bool, int, float, complex, torch.SymIntNode, torch.SymFloatNode]
|
||||
# NumberType = Union[bool, int, float, complex, torch.SymInt, torch.SymFloat]
|
||||
NumberType = Union[bool, int, float, complex]
|
||||
|
||||
Number = (bool, int, float, complex, torch.SymIntNode, torch.SymFloatNode)
|
||||
Number = (bool, int, float, complex, torch.SymInt, torch.SymFloat)
|
||||
# I don't call it Integral because numbers.Integral includes bool, but IntLike
|
||||
# does not
|
||||
Dim = int
|
||||
IntLike = (int, torch.SymIntNode)
|
||||
FloatLike = (float, torch.SymFloatNode)
|
||||
IntLike = (int, torch.SymInt)
|
||||
FloatLike = (float, torch.SymFloat)
|
||||
IntWithoutSymInt = int
|
||||
FloatWithoutSymFloat = float
|
||||
DeviceLikeType = Union[str, torch.device]
|
||||
@ -1113,10 +1113,10 @@ class RETURN_TYPE(Enum):
|
||||
|
||||
|
||||
# TODO: when NumberType contains the sym types, can simplify this
|
||||
def number_type(x: Union[NumberType, torch.SymIntNode, torch.SymFloatNode]) -> Type:
|
||||
if isinstance(x, torch.SymIntNode):
|
||||
def number_type(x: Union[NumberType, torch.SymInt, torch.SymFloat]) -> Type:
|
||||
if isinstance(x, torch.SymInt):
|
||||
return int
|
||||
elif isinstance(x, torch.SymFloatNode):
|
||||
elif isinstance(x, torch.SymFloat):
|
||||
return float
|
||||
else:
|
||||
return type(x)
|
||||
|
@ -656,7 +656,7 @@ class FakeTensorMode(TorchDispatchMode):
|
||||
return args[0].fake_device
|
||||
|
||||
flat_arg_fake_tensors = tree_flatten_only(FakeTensor, (args, kwargs))
|
||||
flat_symints = tree_flatten_only(torch.SymIntNode, (args, kwargs))
|
||||
flat_symints = tree_flatten_only(torch.SymInt, (args, kwargs))
|
||||
has_symbolic_sizes = (
|
||||
any([i._has_symbolic_sizes_strides for i in flat_arg_fake_tensors])
|
||||
or len(flat_symints) > 0
|
||||
|
@ -59,7 +59,7 @@ PyObject* THPSize_NewFromSymSizes(const at::Tensor& self_) {
|
||||
TORCH_CHECK(
|
||||
!torch::jit::tracer::isTracing(),
|
||||
"JIT Tracing of SymInts isn't supported");
|
||||
auto py_symint = py::cast(si.toSymIntNodeImpl()).release().ptr();
|
||||
auto py_symint = py::cast(si).release().ptr();
|
||||
if (!py_symint)
|
||||
throw python_error();
|
||||
PyTuple_SET_ITEM(ret.get(), i, py_symint);
|
||||
@ -98,7 +98,7 @@ static PyObject* THPSize_pynew(
|
||||
if (THPUtils_checkLong(item)) {
|
||||
continue;
|
||||
}
|
||||
if (torch::is_symint_node(item)) {
|
||||
if (torch::is_symint(item)) {
|
||||
continue;
|
||||
}
|
||||
if (torch::jit::tracer::isTracing() && isTracedZeroDimVar(item)) {
|
||||
@ -135,7 +135,7 @@ static PyObject* THPSize_repr(THPSize* self) {
|
||||
auto item = PyTuple_GET_ITEM(self, i);
|
||||
auto ih = py::handle(item);
|
||||
|
||||
repr += torch::is_symint_node(ih)
|
||||
repr += torch::is_symint(ih)
|
||||
? std::string(py::str(ih))
|
||||
: std::to_string(THPUtils_unpackLong(PyTuple_GET_ITEM(self, i)));
|
||||
}
|
||||
|
@ -2646,9 +2646,8 @@ c10::SymInt ConcretePyInterpreterVTable::sym_numel(
|
||||
"Cannot call numel on a tensor with symbolic shapes/strides");
|
||||
return self->sym_numel_default();
|
||||
}
|
||||
return torch::is_symint_node(out)
|
||||
? out.cast<c10::SymIntNodeImpl*>()->toSymInt()
|
||||
: c10::SymInt{py::cast<int64_t>(out)};
|
||||
return torch::is_symint(out) ? out.cast<c10::SymInt>()
|
||||
: c10::SymInt{py::cast<int64_t>(out)};
|
||||
}
|
||||
|
||||
c10::SymInt ConcretePyInterpreterVTable::sym_storage_offset(
|
||||
@ -2669,9 +2668,8 @@ c10::SymInt ConcretePyInterpreterVTable::sym_storage_offset(
|
||||
if (out.is(py::none())) {
|
||||
return self->sym_storage_offset_default();
|
||||
}
|
||||
return torch::is_symint_node(out)
|
||||
? out.cast<c10::SymIntNodeImpl*>()->toSymInt()
|
||||
: c10::SymInt{py::cast<int64_t>(out)};
|
||||
return torch::is_symint(out) ? out.cast<c10::SymInt>()
|
||||
: c10::SymInt{py::cast<int64_t>(out)};
|
||||
}
|
||||
|
||||
c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_strides(
|
||||
@ -2701,9 +2699,8 @@ c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_strides(
|
||||
py::list symints;
|
||||
for (auto it = out.begin(); it != out.end(); it++) {
|
||||
auto elm = *it;
|
||||
auto si = torch::is_symint_node(elm)
|
||||
? elm.cast<c10::SymIntNodeImpl*>()->toSymInt()
|
||||
: c10::SymInt{py::cast<int64_t>(elm)};
|
||||
auto si = torch::is_symint(elm) ? elm.cast<c10::SymInt>()
|
||||
: c10::SymInt{py::cast<int64_t>(elm)};
|
||||
symints.append(si.as_int_unchecked());
|
||||
}
|
||||
|
||||
|
@ -13,7 +13,7 @@
|
||||
#if (!defined(FBCODE_CAFFE2) && defined(BUILD_ONEDNN_GRAPH))
|
||||
#include <torch/csrc/jit/codegen/onednn/interface.h>
|
||||
#endif
|
||||
#include <c10/core/SymIntNodeImpl.h>
|
||||
#include <c10/core/SymNodeImpl.h>
|
||||
#include <torch/csrc/jit/frontend/ir_emitter.h>
|
||||
#include <torch/csrc/jit/frontend/tracer.h>
|
||||
#include <torch/csrc/jit/ir/irparser.h>
|
||||
@ -99,7 +99,6 @@
|
||||
#include <torch/csrc/jit/tensorexpr/tensorexpr_init.h>
|
||||
#include <torch/csrc/utils/cpp_stacktraces.h>
|
||||
|
||||
#include <c10/core/SymFloat.h>
|
||||
#include <c10/macros/Export.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <c10/util/signal_handler.h>
|
||||
@ -126,249 +125,11 @@ using c10::Argument;
|
||||
using c10::FunctionSchema;
|
||||
using c10::SchemaArgType;
|
||||
using c10::SchemaArgument;
|
||||
using c10::SymFloat;
|
||||
using c10::SymFloatNode;
|
||||
using c10::SymIntNode;
|
||||
using c10::SymNode;
|
||||
using caffe2::serialize::PyTorchStreamReader;
|
||||
using caffe2::serialize::PyTorchStreamWriter;
|
||||
using torch::utils::SchemaInfo;
|
||||
|
||||
static c10::SymIntNode toSymIntNode(c10::SymIntNode a, py::object b) {
|
||||
return torch::is_symint_node(b) ? b.cast<c10::SymIntNode>()
|
||||
: a->wrap(b.cast<int64_t>());
|
||||
}
|
||||
|
||||
static c10::SymFloatNode toSymFloatNode(c10::SymFloatNode a, py::object b) {
|
||||
if (torch::is_symfloat_node(b)) {
|
||||
return b.cast<c10::SymFloatNode>();
|
||||
} else if (torch::is_symint_node(b)) {
|
||||
return b.cast<c10::SymIntNode>()->sym_float();
|
||||
} else {
|
||||
return a->wrap(b.cast<double>());
|
||||
}
|
||||
}
|
||||
|
||||
class PythonSymIntNodeImpl : public c10::SymIntNodeImpl {
|
||||
public:
|
||||
PythonSymIntNodeImpl(py::object pyobj) : c10::SymIntNodeImpl() {
|
||||
pyobj_ = std::make_shared<c10::SafePyObject>(
|
||||
pyobj.release().ptr(), getPyInterpreter());
|
||||
};
|
||||
|
||||
virtual SymIntNode clone() override {
|
||||
py::gil_scoped_acquire acquire;
|
||||
auto r = getPyObj().attr("clone")();
|
||||
return c10::make_intrusive<PythonSymIntNodeImpl>(r);
|
||||
}
|
||||
|
||||
virtual SymIntNode wrap(int64_t num) override {
|
||||
py::gil_scoped_acquire acquire;
|
||||
auto r = getPyObj().attr("wrap")(num);
|
||||
return c10::make_intrusive<PythonSymIntNodeImpl>(r);
|
||||
}
|
||||
|
||||
virtual bool bool_() override {
|
||||
py::gil_scoped_acquire acquire;
|
||||
return getPyObj().attr("__bool__")().is(py::handle(Py_True));
|
||||
}
|
||||
|
||||
virtual int64_t guard_int(const char* file, int64_t line) override {
|
||||
py::gil_scoped_acquire acquire;
|
||||
return getPyObj().attr("guard_int")(file, line).cast<int64_t>();
|
||||
}
|
||||
|
||||
virtual int64_t int_() override {
|
||||
py::gil_scoped_acquire acquire;
|
||||
return getPyObj().attr("__int__")().cast<int64_t>();
|
||||
}
|
||||
|
||||
SymFloatNode sym_float() override;
|
||||
|
||||
virtual std::string str() override {
|
||||
py::gil_scoped_acquire acquire;
|
||||
return getPyObj().attr("__str__")().cast<std::string>();
|
||||
}
|
||||
|
||||
virtual SymIntNode dispatch_common_(
|
||||
const char* fname,
|
||||
const SymIntNode& other) {
|
||||
auto pother = dynamic_cast<PythonSymIntNodeImpl*>(other.get());
|
||||
TORCH_CHECK(pother);
|
||||
py::gil_scoped_acquire acquire;
|
||||
auto r = getPyObj().attr(fname)(pother->getPyObj());
|
||||
return c10::make_intrusive<PythonSymIntNodeImpl>(r);
|
||||
}
|
||||
|
||||
virtual SymIntNode dispatch_common_(const char* fname) {
|
||||
py::gil_scoped_acquire acquire;
|
||||
auto r = getPyObj().attr(fname)();
|
||||
return c10::make_intrusive<PythonSymIntNodeImpl>(r);
|
||||
}
|
||||
|
||||
virtual SymIntNode add(const SymIntNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
virtual SymIntNode sub(const SymIntNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
virtual SymIntNode mul(const SymIntNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
virtual SymFloatNode truediv(const SymIntNode& other) override;
|
||||
|
||||
virtual SymIntNode floordiv(const SymIntNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
virtual SymIntNode mod(const SymIntNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
virtual SymIntNode eq(const SymIntNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
virtual SymIntNode gt(const SymIntNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
virtual SymIntNode lt(const SymIntNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
virtual SymIntNode le(const SymIntNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
virtual SymIntNode ge(const SymIntNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
virtual SymIntNode min(const SymIntNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
virtual SymIntNode max(const SymIntNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
virtual SymIntNode ceil() override {
|
||||
return dispatch_common_(__FUNCTION__);
|
||||
}
|
||||
|
||||
virtual SymIntNode neg() override {
|
||||
return dispatch_common_(__FUNCTION__);
|
||||
}
|
||||
|
||||
py::handle getPyObj() {
|
||||
return py::handle(pyobj_.get()->ptr(getPyInterpreter()));
|
||||
}
|
||||
std::shared_ptr<c10::SafePyObject> pyobj_ = nullptr;
|
||||
};
|
||||
|
||||
class PythonSymFloatNodeImpl : public c10::SymFloatNodeImpl {
|
||||
public:
|
||||
PythonSymFloatNodeImpl(py::object pyobj) : c10::SymFloatNodeImpl() {
|
||||
pyobj_ = std::make_shared<c10::SafePyObject>(
|
||||
pyobj.release().ptr(), getPyInterpreter());
|
||||
};
|
||||
|
||||
virtual SymFloatNode wrap(double num) override {
|
||||
py::gil_scoped_acquire acquire;
|
||||
auto r = getPyObj().attr("wrap")(num);
|
||||
return c10::make_intrusive<PythonSymFloatNodeImpl>(r);
|
||||
}
|
||||
|
||||
virtual std::string str() override {
|
||||
py::gil_scoped_acquire acquire;
|
||||
return getPyObj().attr("__str__")().cast<std::string>();
|
||||
}
|
||||
|
||||
SymFloatNode dispatch_common_(const char* fname, const SymFloatNode& other) {
|
||||
auto pother = dynamic_cast<PythonSymFloatNodeImpl*>(other.get());
|
||||
TORCH_CHECK(pother);
|
||||
py::gil_scoped_acquire acquire;
|
||||
auto r = getPyObj().attr(fname)(pother->getPyObj());
|
||||
return c10::make_intrusive<PythonSymFloatNodeImpl>(r);
|
||||
}
|
||||
|
||||
SymFloatNode add(const SymFloatNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
SymFloatNode sub(const SymFloatNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
SymFloatNode mul(const SymFloatNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
SymFloatNode truediv(const SymFloatNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
SymFloatNode pow(const SymFloatNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
SymFloatNode eq(const SymFloatNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
SymFloatNode gt(const SymFloatNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
SymFloatNode lt(const SymFloatNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
SymFloatNode le(const SymFloatNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
SymFloatNode ge(const SymFloatNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
SymIntNode ceil() override;
|
||||
SymIntNode floor() override;
|
||||
|
||||
py::handle getPyObj() {
|
||||
return py::handle(pyobj_.get()->ptr(getPyInterpreter()));
|
||||
}
|
||||
std::shared_ptr<c10::SafePyObject> pyobj_ = nullptr;
|
||||
};
|
||||
|
||||
SymFloatNode PythonSymIntNodeImpl::truediv(const SymIntNode& other) {
|
||||
auto pother = dynamic_cast<PythonSymIntNodeImpl*>(other.get());
|
||||
TORCH_CHECK(pother);
|
||||
py::gil_scoped_acquire acquire;
|
||||
auto r = getPyObj().attr("truediv")(pother->getPyObj());
|
||||
return c10::make_intrusive<PythonSymFloatNodeImpl>(r);
|
||||
}
|
||||
|
||||
SymFloatNode PythonSymIntNodeImpl::sym_float() {
|
||||
py::gil_scoped_acquire acquire;
|
||||
return c10::make_intrusive<PythonSymFloatNodeImpl>(
|
||||
getPyObj().attr("__sym_float__")());
|
||||
}
|
||||
|
||||
SymIntNode PythonSymFloatNodeImpl::ceil() {
|
||||
py::gil_scoped_acquire acquire;
|
||||
auto r = getPyObj().attr("ceil")();
|
||||
return c10::make_intrusive<PythonSymIntNodeImpl>(r);
|
||||
}
|
||||
|
||||
SymIntNode PythonSymFloatNodeImpl::floor() {
|
||||
py::gil_scoped_acquire acquire;
|
||||
auto r = getPyObj().attr("floor")();
|
||||
return c10::make_intrusive<PythonSymIntNodeImpl>(r);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
using autograd::variable_list;
|
||||
@ -1381,276 +1142,41 @@ void initJITBindings(PyObject* module) {
|
||||
}
|
||||
});
|
||||
|
||||
auto symint_class =
|
||||
py::class_<c10::SymIntNodeImpl, c10::SymIntNode>(m, "SymIntNode")
|
||||
.def_static(
|
||||
"new_symint",
|
||||
[](py::object obj) -> c10::SymIntNode {
|
||||
return c10::make_intrusive<PythonSymIntNodeImpl>(obj);
|
||||
})
|
||||
.def(
|
||||
"get_pyobj",
|
||||
[](c10::SymIntNode a) -> py::object {
|
||||
if (auto* psn = dynamic_cast<PythonSymIntNodeImpl*>(a.get())) {
|
||||
return py::reinterpret_borrow<py::object>(psn->getPyObj());
|
||||
}
|
||||
return py::none();
|
||||
})
|
||||
.def(
|
||||
"__add__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->add(snb);
|
||||
})
|
||||
.def(
|
||||
"__radd__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return snb->add(a);
|
||||
})
|
||||
.def(
|
||||
"__sub__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->sub(snb);
|
||||
})
|
||||
.def(
|
||||
"__rsub__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return snb->sub(a);
|
||||
})
|
||||
.def(
|
||||
"__mul__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->mul(snb);
|
||||
})
|
||||
.def(
|
||||
"__rmul__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return snb->mul(a);
|
||||
})
|
||||
.def(
|
||||
"__truediv__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymFloatNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->truediv(snb);
|
||||
})
|
||||
.def(
|
||||
"__rtruediv__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymFloatNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return snb->truediv(a);
|
||||
})
|
||||
.def(
|
||||
"__floordiv__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->floordiv(snb);
|
||||
})
|
||||
.def(
|
||||
"__rfloordiv__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return snb->floordiv(a);
|
||||
})
|
||||
.def(
|
||||
"__mod__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->mod(snb);
|
||||
})
|
||||
.def(
|
||||
"__rmod__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return snb->mod(a);
|
||||
})
|
||||
.def(
|
||||
"__pow__",
|
||||
[](c10::SymIntNode a, py::object b) -> py::object {
|
||||
if (PyFloat_Check(b.ptr())) {
|
||||
auto float_a = a->sym_float();
|
||||
return py::cast(
|
||||
float_a->pow(float_a->wrap(py::cast<double>(b))));
|
||||
}
|
||||
// TODO: integer pow
|
||||
return py::reinterpret_borrow<py::object>(Py_NotImplemented);
|
||||
})
|
||||
// TODO: rpow
|
||||
.def(
|
||||
"__eq__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->eq(snb);
|
||||
})
|
||||
.def(
|
||||
"__gt__",
|
||||
[](c10::SymIntNode a, py::object b) {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->gt(snb);
|
||||
})
|
||||
.def(
|
||||
"__lt__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->lt(snb);
|
||||
})
|
||||
.def(
|
||||
"__le__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->le(snb);
|
||||
})
|
||||
.def(
|
||||
"__ge__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->ge(snb);
|
||||
})
|
||||
.def(
|
||||
"__ceil__",
|
||||
[](c10::SymIntNode a) -> c10::SymIntNode { return a->ceil(); })
|
||||
.def(
|
||||
"__neg__",
|
||||
[](c10::SymIntNode a) -> c10::SymIntNode { return a->neg(); })
|
||||
.def(
|
||||
"__min__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->min(snb);
|
||||
})
|
||||
.def(
|
||||
"__max__",
|
||||
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
|
||||
auto snb = toSymIntNode(a, b);
|
||||
return a->max(snb);
|
||||
})
|
||||
.def("__bool__", [](c10::SymIntNode a) { return a->bool_(); })
|
||||
.def("__int__", [](c10::SymIntNode a) { return a->int_(); })
|
||||
// Intentionally don't set file line, as the Python backtrace matters
|
||||
// more here
|
||||
.def(
|
||||
"guard_int",
|
||||
[](c10::SymIntNode a) { return a->guard_int(nullptr, 0); })
|
||||
.def(
|
||||
"__sym_float__",
|
||||
[](c10::SymIntNode a) {
|
||||
// TODO: remove dynamic cast when sym_float is in base class
|
||||
auto* psn = dynamic_cast<PythonSymIntNodeImpl*>(a.get());
|
||||
TORCH_INTERNAL_ASSERT(psn);
|
||||
return psn->sym_float();
|
||||
})
|
||||
.def("__str__", [](c10::SymIntNode a) { return a->str(); })
|
||||
.def("__repr__", [](c10::SymIntNode a) { return a->str(); });
|
||||
|
||||
py::class_<c10::SymFloatNodeImpl, c10::SymFloatNode>(m, "SymFloatNode")
|
||||
.def_static(
|
||||
"new_symfloat",
|
||||
[](py::object obj) -> c10::SymFloatNode {
|
||||
return c10::make_intrusive<PythonSymFloatNodeImpl>(obj);
|
||||
})
|
||||
.def(
|
||||
"__add__",
|
||||
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
|
||||
auto snb = toSymFloatNode(a, b);
|
||||
return a->add(snb);
|
||||
})
|
||||
.def(
|
||||
"__radd__",
|
||||
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
|
||||
auto snb = toSymFloatNode(a, b);
|
||||
return snb->add(a);
|
||||
})
|
||||
.def(
|
||||
"__sub__",
|
||||
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
|
||||
auto snb = toSymFloatNode(a, b);
|
||||
return a->sub(snb);
|
||||
})
|
||||
.def(
|
||||
"__mul__",
|
||||
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
|
||||
auto snb = toSymFloatNode(a, b);
|
||||
return a->mul(snb);
|
||||
})
|
||||
.def(
|
||||
"__rmul__",
|
||||
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
|
||||
auto snb = toSymFloatNode(a, b);
|
||||
return snb->mul(a);
|
||||
})
|
||||
.def(
|
||||
"__truediv__",
|
||||
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
|
||||
auto snb = toSymFloatNode(a, b);
|
||||
return a->truediv(snb);
|
||||
})
|
||||
.def(
|
||||
"__rtruediv__",
|
||||
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
|
||||
auto snb = toSymFloatNode(a, b);
|
||||
return snb->truediv(a);
|
||||
})
|
||||
.def(
|
||||
"__eq__",
|
||||
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
|
||||
auto snb = toSymFloatNode(a, b);
|
||||
return a->eq(snb);
|
||||
})
|
||||
.def(
|
||||
"__gt__",
|
||||
[](c10::SymFloatNode a, py::object b) {
|
||||
auto snb = toSymFloatNode(a, b);
|
||||
return a->gt(snb);
|
||||
})
|
||||
.def(
|
||||
"__lt__",
|
||||
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
|
||||
auto snb = toSymFloatNode(a, b);
|
||||
return a->lt(snb);
|
||||
})
|
||||
.def(
|
||||
"__le__",
|
||||
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
|
||||
auto snb = toSymFloatNode(a, b);
|
||||
return a->le(snb);
|
||||
})
|
||||
.def(
|
||||
"__ge__",
|
||||
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
|
||||
auto snb = toSymFloatNode(a, b);
|
||||
return a->ge(snb);
|
||||
})
|
||||
.def(
|
||||
"__pow__",
|
||||
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
|
||||
auto snb = toSymFloatNode(a, b);
|
||||
return a->pow(snb);
|
||||
})
|
||||
.def(
|
||||
"__rpow__",
|
||||
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
|
||||
auto snb = toSymFloatNode(a, b);
|
||||
return snb->pow(a);
|
||||
})
|
||||
.def(
|
||||
"__ceil__",
|
||||
[](c10::SymFloatNode a) -> c10::SymIntNode { return a->ceil(); })
|
||||
.def(
|
||||
"__floor__",
|
||||
[](c10::SymFloatNode a) -> c10::SymIntNode { return a->floor(); })
|
||||
.def(
|
||||
"get_pyobj",
|
||||
[](c10::SymFloatNode a) -> py::object {
|
||||
if (auto* psn = dynamic_cast<PythonSymFloatNodeImpl*>(a.get())) {
|
||||
return py::reinterpret_borrow<py::object>(psn->getPyObj());
|
||||
}
|
||||
return py::none();
|
||||
})
|
||||
.def("__str__", [](c10::SymFloatNode a) { return a->str(); });
|
||||
// NB: This isn't actually used for regular PyTorch symbolic tracing;
|
||||
// XLA is what needs this
|
||||
#define SYMNODE_UNARY(n) .def(#n, [](c10::SymNode a) { return a->n(); })
|
||||
#define SYMNODE_UNARY2(n2, n) .def(#n2, [](c10::SymNode a) { return a->n(); })
|
||||
#define SYMNODE_BINARY(n) \
|
||||
.def(#n, [](c10::SymNode a, c10::SymNode b) { return a->n(b); })
|
||||
auto symnode_class =
|
||||
py::class_<c10::SymNodeImpl, c10::SymNode>(m, "_SymNode")
|
||||
// These DO NOT install magic methods; the SymInt/SymFloat wrapper in
|
||||
// Python is responsible for this
|
||||
SYMNODE_UNARY(clone)
|
||||
// Named these for consistency with inner python class, but maybe
|
||||
// should change the python side
|
||||
SYMNODE_UNARY2(__bool__, bool_) SYMNODE_UNARY2(__int__, int_)
|
||||
SYMNODE_UNARY2(__sym_int__, sym_int) SYMNODE_UNARY2(
|
||||
__sym_float__, sym_float) SYMNODE_BINARY(add) SYMNODE_BINARY(sub)
|
||||
SYMNODE_BINARY(mul) SYMNODE_BINARY(truediv) SYMNODE_BINARY(pow)
|
||||
SYMNODE_BINARY(floordiv) SYMNODE_BINARY(mod) SYMNODE_BINARY(
|
||||
eq) SYMNODE_BINARY(gt) SYMNODE_BINARY(lt)
|
||||
SYMNODE_BINARY(le) SYMNODE_BINARY(ge) SYMNODE_BINARY(min)
|
||||
SYMNODE_BINARY(max) SYMNODE_UNARY(ceil)
|
||||
SYMNODE_UNARY(floor) SYMNODE_UNARY(neg)
|
||||
// Intentionally don't set file line, as the
|
||||
// Python backtrace matters more here
|
||||
.def(
|
||||
"guard_int",
|
||||
[](c10::SymNode a) {
|
||||
return a->guard_int(nullptr, 0);
|
||||
})
|
||||
.def(
|
||||
"__str__",
|
||||
[](c10::SymNode a) { return a->str(); })
|
||||
.def("__repr__", [](c10::SymNode a) {
|
||||
return a->str();
|
||||
});
|
||||
|
||||
// NOLINTNEXTLINE(bugprone-unused-raii)
|
||||
py::class_<CompleteArgumentSpec>(m, "CompleteArgumentSpec")
|
||||
|
@ -80,10 +80,10 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
|
||||
scalar = at::Scalar(THPUtils_unpackComplexDouble(obj.ptr()));
|
||||
} else if (THPUtils_checkDouble(obj.ptr())) {
|
||||
scalar = at::Scalar(THPUtils_unpackDouble(obj.ptr()));
|
||||
} else if (torch::is_symint_node(py::handle(obj))) {
|
||||
} else if (torch::is_symint(py::handle(obj))) {
|
||||
save_symint = true;
|
||||
scalar = at::Scalar(7777777);
|
||||
} else if (torch::is_symfloat_node(py::handle(obj))) {
|
||||
} else if (torch::is_symfloat(py::handle(obj))) {
|
||||
save_symint = true;
|
||||
scalar = at::Scalar(std::numeric_limits<double>::quiet_NaN());
|
||||
} else {
|
||||
@ -161,12 +161,12 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
|
||||
return py::cast<int64_t>(obj);
|
||||
}
|
||||
case TypeKind::SymIntType:
|
||||
if (torch::is_symint_node(obj.ptr())) {
|
||||
if (torch::is_symint(obj.ptr())) {
|
||||
return py::cast<c10::SymInt>(obj);
|
||||
}
|
||||
return py::cast<int64_t>(obj);
|
||||
case TypeKind::SymFloatType:
|
||||
if (torch::is_symfloat_node(obj.ptr())) {
|
||||
if (torch::is_symfloat(obj.ptr())) {
|
||||
return py::cast<c10::SymFloat>(obj);
|
||||
}
|
||||
return py::cast<double>(obj);
|
||||
@ -253,7 +253,7 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
|
||||
bool is_symbolic = false;
|
||||
for (auto it = obj.begin(); it != obj.end(); it++) {
|
||||
auto elm = *it;
|
||||
if (torch::is_symint_node(elm)) {
|
||||
if (torch::is_symint(elm)) {
|
||||
is_symbolic = true;
|
||||
break;
|
||||
}
|
||||
@ -269,7 +269,7 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
|
||||
for (auto it = obj.begin(); it != obj.end(); it++) {
|
||||
auto elm = *it;
|
||||
// TODO: what about SymInt conversion to SymFloat?
|
||||
if (torch::is_symfloat_node(elm)) {
|
||||
if (torch::is_symfloat(elm)) {
|
||||
is_symbolic = true;
|
||||
break;
|
||||
}
|
||||
@ -442,9 +442,9 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
|
||||
} else if (PyComplex_CheckExact(obj.ptr())) {
|
||||
auto c_obj = py::cast<std::complex<double>>(obj.ptr());
|
||||
return static_cast<c10::complex<double>>(c_obj);
|
||||
} else if (torch::is_symint_node(obj)) {
|
||||
} else if (torch::is_symint(obj)) {
|
||||
return py::cast<c10::SymInt>(obj);
|
||||
} else if (torch::is_symfloat_node(obj)) {
|
||||
} else if (torch::is_symfloat(obj)) {
|
||||
return py::cast<c10::SymFloat>(obj);
|
||||
} else {
|
||||
throw py::cast_error(
|
||||
|
@ -136,10 +136,10 @@ static inline NodePtr MakeSizeDiv(const Value& a, const Value& b) {
|
||||
|
||||
inline Value GetSymIntValue(c10::SymInt a) {
|
||||
return Value(
|
||||
a.is_symbolic() ? dynamic_cast<torch::lazy::SymIntNodeImpl*>(
|
||||
a.toSymIntNodeImpl().get())
|
||||
->node_
|
||||
: MakeScalar(a.as_int_unchecked(), at::kLong),
|
||||
a.is_symbolic()
|
||||
? dynamic_cast<torch::lazy::SymNodeImpl*>(a.toSymNodeImpl().get())
|
||||
->node_
|
||||
: MakeScalar(a.as_int_unchecked(), at::kLong),
|
||||
0);
|
||||
}
|
||||
|
||||
|
@ -451,11 +451,11 @@ std::vector<Shape> compute_shape_expand(
|
||||
std::vector<int64_t> target_size(_sizes.size());
|
||||
for (const auto idx : c10::irange(_sizes.size())) {
|
||||
if (_sizes[idx].is_symbolic()) {
|
||||
c10::SymIntNode symbolicIntNode = _sizes[idx].toSymIntNodeImpl();
|
||||
auto* lazySymIntNode =
|
||||
dynamic_cast<torch::lazy::SymIntNodeImpl*>(symbolicIntNode.get());
|
||||
TORCH_INTERNAL_ASSERT(lazySymIntNode);
|
||||
auto size_node = lazySymIntNode->node_;
|
||||
c10::SymNode symbolicIntNode = _sizes[idx].toSymNodeImpl();
|
||||
auto* lazySymNode =
|
||||
dynamic_cast<torch::lazy::SymNodeImpl*>(symbolicIntNode.get());
|
||||
TORCH_INTERNAL_ASSERT(lazySymNode);
|
||||
auto size_node = lazySymNode->node_;
|
||||
auto static_value =
|
||||
std::dynamic_pointer_cast<torch::lazy::DimensionNode>(size_node)
|
||||
->getStaticValue();
|
||||
|
@ -4,7 +4,7 @@
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/core/SymInt.h>
|
||||
#include <c10/core/SymIntArrayRef.h>
|
||||
#include <c10/core/SymIntNodeImpl.h>
|
||||
#include <c10/core/SymNodeImpl.h>
|
||||
#include <c10/macros/Export.h>
|
||||
#include <c10/util/Optional.h>
|
||||
#include <torch/csrc/lazy/backend/backend_data.h>
|
||||
|
@ -1,6 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/SymIntNodeImpl.h>
|
||||
#include <c10/core/SymNodeImpl.h>
|
||||
#include <c10/util/intrusive_ptr.h>
|
||||
#include <torch/csrc/lazy/backend/backend_data.h>
|
||||
#include <torch/csrc/lazy/backend/backend_device.h>
|
||||
@ -10,12 +10,9 @@
|
||||
namespace torch {
|
||||
namespace lazy {
|
||||
|
||||
class TORCH_API SymIntNodeImpl : public c10::SymIntNodeImpl {
|
||||
class TORCH_API SymNodeImpl : public c10::SymNodeImpl {
|
||||
public:
|
||||
SymIntNodeImpl(NodePtr ptr) : node_(std::move(ptr)){};
|
||||
c10::SymIntNode add(const c10::SymIntNode& other) override {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
}
|
||||
SymNodeImpl(NodePtr ptr) : node_(std::move(ptr)){};
|
||||
NodePtr node_;
|
||||
};
|
||||
|
||||
|
@ -685,7 +685,7 @@ static bool is_int_list(
|
||||
// NB: do NOT check that the later arguments are ints, as this is
|
||||
// BC-breaking for FX
|
||||
for (int i = 1; i < len; i++) {
|
||||
if (torch::is_symint_node(
|
||||
if (torch::is_symint(
|
||||
py::reinterpret_steal<py::object>(PySequence_GetItem(obj, i)))) {
|
||||
if (failed_idx != nullptr) {
|
||||
*failed_idx = i;
|
||||
@ -716,9 +716,9 @@ static bool is_int_list(
|
||||
static bool is_int_or_symint(PyObject* obj) {
|
||||
// THPUtils_checkIndex may call __index__ or __int__
|
||||
// which may have side effects if obj is a symint node
|
||||
// so we do `is_symint_node` check first
|
||||
// so we do `is_symint` check first
|
||||
// TODO: maybe we should be using checkLong here?
|
||||
return torch::is_symint_node(py::handle(obj)) || THPUtils_checkIndex(obj);
|
||||
return torch::is_symint(py::handle(obj)) || THPUtils_checkIndex(obj);
|
||||
}
|
||||
|
||||
static bool is_int_or_symint_list(
|
||||
@ -1570,13 +1570,13 @@ at::Tensor PythonArgs::tensor_slow(int i) {
|
||||
// NB: we DO NOT put symbolic ints/floats into the Scalar itself,
|
||||
// because although Scalar supports SymInt/SymFloat, the subsequent
|
||||
// conversion to Tensor does not. Instead, do it out of band.
|
||||
} else if (torch::is_symint_node(py::handle(obj))) {
|
||||
} else if (torch::is_symint(py::handle(obj))) {
|
||||
save_symint = true;
|
||||
// This scalar value doesn't matter, it shouldn't ever actually
|
||||
// get read out. Make it a big and weird looking number to help
|
||||
// people figure out if there's aproblem.
|
||||
scalar = at::Scalar(7777777);
|
||||
} else if (torch::is_symfloat_node(py::handle(obj))) {
|
||||
} else if (torch::is_symfloat(py::handle(obj))) {
|
||||
save_symint = true;
|
||||
scalar = at::Scalar(std::numeric_limits<double>::quiet_NaN());
|
||||
} else {
|
||||
@ -1633,11 +1633,11 @@ at::Scalar PythonArgs::scalar_slow(PyObject* arg) {
|
||||
return at::Scalar(THPUtils_unpackComplexDouble(arg));
|
||||
}
|
||||
|
||||
if (torch::is_symint_node(arg)) {
|
||||
if (torch::is_symint(arg)) {
|
||||
return at::Scalar(py::cast<c10::SymInt>(arg));
|
||||
}
|
||||
|
||||
if (torch::is_symfloat_node(arg)) {
|
||||
if (torch::is_symfloat(arg)) {
|
||||
return at::Scalar(py::cast<c10::SymFloat>(arg));
|
||||
}
|
||||
|
||||
|
@ -61,6 +61,7 @@
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
#include <torch/csrc/utils/python_numbers.h>
|
||||
#include <torch/csrc/utils/python_strings.h>
|
||||
#include <torch/csrc/utils/python_symnode.h>
|
||||
#include <torch/csrc/utils/six.h>
|
||||
|
||||
#include <ATen/PythonTorchFunctionTLS.h>
|
||||
@ -69,7 +70,7 @@
|
||||
#include <c10/util/irange.h>
|
||||
|
||||
#include <c10/core/SymFloat.h>
|
||||
#include <c10/core/SymIntNodeImpl.h>
|
||||
#include <c10/core/SymNodeImpl.h>
|
||||
|
||||
#include <array>
|
||||
#include <cstddef>
|
||||
@ -78,30 +79,6 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace torch {
|
||||
|
||||
inline bool is_symint_node(py::handle obj) {
|
||||
auto static tp_symn = py::type::of<c10::SymIntNodeImpl>();
|
||||
if (py::isinstance(obj, tp_symn)) {
|
||||
TORCH_CHECK(
|
||||
!jit::tracer::isTracing(), "JIT tracing of SymInts isn't supported!");
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
inline bool is_symfloat_node(py::handle obj) {
|
||||
auto static tp_symn = py::type::of<c10::SymFloatNodeImpl>();
|
||||
if (py::isinstance(obj, tp_symn)) {
|
||||
TORCH_CHECK(
|
||||
!jit::tracer::isTracing(), "JIT tracing of SymFloats isn't supported!");
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace torch
|
||||
|
||||
namespace pybind11 {
|
||||
namespace detail {
|
||||
template <>
|
||||
@ -109,8 +86,10 @@ struct type_caster<c10::SymInt> {
|
||||
public:
|
||||
PYBIND11_TYPE_CASTER(c10::SymInt, _("SymInt"));
|
||||
bool load(py::handle src, bool) {
|
||||
if (torch::is_symint_node(src)) {
|
||||
value = src.cast<c10::SymIntNodeImpl*>()->toSymInt();
|
||||
if (torch::is_symint(src)) {
|
||||
value = c10::SymInt(static_cast<c10::SymNode>(
|
||||
c10::make_intrusive<torch::impl::PythonSymNodeImpl>(
|
||||
src.attr("node"))));
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -126,8 +105,15 @@ struct type_caster<c10::SymInt> {
|
||||
c10::SymInt si,
|
||||
return_value_policy /* policy */,
|
||||
handle /* parent */) {
|
||||
return si.is_symbolic() ? py::cast(si.toSymIntNodeImpl()).release()
|
||||
: py::cast(si.expect_int()).release();
|
||||
if (si.is_symbolic()) {
|
||||
// TODO: generalize this to work with C++ backed class
|
||||
auto* py_node = dynamic_cast<torch::impl::PythonSymNodeImpl*>(
|
||||
si.toSymNodeImpl().get());
|
||||
TORCH_INTERNAL_ASSERT(py_node);
|
||||
return torch::get_symint_class()(py_node->getPyObj()).release();
|
||||
} else {
|
||||
return py::cast(si.as_int_unchecked()).release();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@ -136,8 +122,10 @@ struct type_caster<c10::SymFloat> {
|
||||
public:
|
||||
PYBIND11_TYPE_CASTER(c10::SymFloat, _("SymFloat"));
|
||||
bool load(py::handle src, bool) {
|
||||
if (torch::is_symfloat_node(src)) {
|
||||
value = src.cast<c10::SymFloatNodeImpl*>()->toSymFloat();
|
||||
if (torch::is_symfloat(src)) {
|
||||
value = c10::SymFloat(static_cast<c10::SymNode>(
|
||||
c10::make_intrusive<torch::impl::PythonSymNodeImpl>(
|
||||
src.attr("node"))));
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -153,8 +141,15 @@ struct type_caster<c10::SymFloat> {
|
||||
c10::SymFloat si,
|
||||
return_value_policy /* policy */,
|
||||
handle /* parent */) {
|
||||
return si.is_symbolic() ? py::cast(si.toSymFloatNodeImpl()).release()
|
||||
: py::cast(si.expect_float()).release();
|
||||
if (si.is_symbolic()) {
|
||||
// TODO: generalize this to work with C++ backed class
|
||||
auto* py_node = dynamic_cast<torch::impl::PythonSymNodeImpl*>(
|
||||
si.toSymNodeImpl().get());
|
||||
TORCH_INTERNAL_ASSERT(py_node);
|
||||
return torch::get_symfloat_class()(py_node->getPyObj()).release();
|
||||
} else {
|
||||
return py::cast(si.as_float_unchecked()).release();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace detail
|
||||
@ -167,8 +162,7 @@ inline bool THPUtils_checkScalar(PyObject* obj) {
|
||||
}
|
||||
#endif
|
||||
return PyFloat_Check(obj) || PyLong_Check(obj) || PyComplex_Check(obj) ||
|
||||
torch::is_symint_node(py::handle(obj)) ||
|
||||
torch::is_symfloat_node(py::handle(obj));
|
||||
torch::is_symint(py::handle(obj)) || torch::is_symfloat(py::handle(obj));
|
||||
}
|
||||
|
||||
namespace torch {
|
||||
@ -574,7 +568,7 @@ inline std::vector<int64_t> PythonArgs::intlist(int i) {
|
||||
|
||||
inline PyObject* toPyObject(c10::SymInt symint) {
|
||||
if (symint.is_symbolic()) {
|
||||
auto r = py::cast(symint.toSymIntNodeImpl()).release().ptr();
|
||||
auto r = py::cast(symint).release().ptr();
|
||||
TORCH_INTERNAL_ASSERT(r);
|
||||
return r;
|
||||
} else {
|
||||
@ -609,8 +603,8 @@ inline std::vector<c10::SymInt> PythonArgs::symintlist(int i) {
|
||||
size1, c10::SymInt(THPUtils_unpackIndex(args[i])));
|
||||
}
|
||||
|
||||
if (size1 > 0 && torch::is_symint_node(py::handle(args[i]))) {
|
||||
auto si = py::handle(args[i]).cast<c10::SymIntNodeImpl*>()->toSymInt();
|
||||
if (size1 > 0 && torch::is_symint(py::handle(args[i]))) {
|
||||
auto si = py::handle(args[i]).cast<c10::SymInt>();
|
||||
return std::vector<c10::SymInt>(size1, si);
|
||||
}
|
||||
|
||||
@ -652,9 +646,8 @@ inline std::vector<c10::SymInt> PythonArgs::symintlist(int i) {
|
||||
res.push_back(var.item<int64_t>());
|
||||
} else {
|
||||
try {
|
||||
if (is_symint_node(py::handle(obj))) {
|
||||
res.push_back(
|
||||
py::handle(obj).cast<c10::SymIntNodeImpl*>()->toSymInt());
|
||||
if (is_symint(py::handle(obj))) {
|
||||
res.push_back(py::handle(obj).cast<c10::SymInt>());
|
||||
} else {
|
||||
res.push_back(c10::SymInt(THPUtils_unpackIndex(obj)));
|
||||
}
|
||||
|
19
torch/csrc/utils/python_symnode.cpp
Normal file
19
torch/csrc/utils/python_symnode.cpp
Normal file
@ -0,0 +1,19 @@
|
||||
#include <torch/csrc/utils/python_symnode.h>
|
||||
|
||||
namespace torch {
|
||||
|
||||
py::handle get_symint_class() {
|
||||
// NB: leak
|
||||
static py::handle symint_class =
|
||||
py::object(py::module::import("torch").attr("SymInt")).release();
|
||||
return symint_class;
|
||||
}
|
||||
|
||||
py::handle get_symfloat_class() {
|
||||
// NB: leak
|
||||
static py::handle symfloat_class =
|
||||
py::object(py::module::import("torch").attr("SymFloat")).release();
|
||||
return symfloat_class;
|
||||
}
|
||||
|
||||
} // namespace torch
|
182
torch/csrc/utils/python_symnode.h
Normal file
182
torch/csrc/utils/python_symnode.h
Normal file
@ -0,0 +1,182 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/SafePyObject.h>
|
||||
#include <c10/core/SymNodeImpl.h>
|
||||
|
||||
#include <torch/csrc/autograd/python_variable.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
|
||||
namespace torch {
|
||||
|
||||
TORCH_PYTHON_API py::handle get_symint_class();
|
||||
TORCH_PYTHON_API py::handle get_symfloat_class();
|
||||
|
||||
// NB: These functions must not be called too early, otherwise torch not setup.
|
||||
// Alternate design is to have torch "register" the object to us
|
||||
inline bool is_symint(py::handle obj) {
|
||||
return py::isinstance(obj, get_symint_class());
|
||||
}
|
||||
inline bool is_symfloat(py::handle obj) {
|
||||
return py::isinstance(obj, get_symfloat_class());
|
||||
}
|
||||
|
||||
namespace impl {
|
||||
|
||||
// This c10::SymNodeImpl simply backends to a Python object that
|
||||
// implements the API. The Python object is the source of truth,
|
||||
// this is just an adapter so C++ calls can get to the object.
|
||||
class PythonSymNodeImpl : public c10::SymNodeImpl {
|
||||
public:
|
||||
PythonSymNodeImpl(py::object pyobj) : c10::SymNodeImpl() {
|
||||
pyobj_ = std::make_shared<c10::SafePyObject>(
|
||||
pyobj.release().ptr(), getPyInterpreter());
|
||||
};
|
||||
|
||||
c10::SymNode wrap_int(int64_t num) override {
|
||||
py::gil_scoped_acquire acquire;
|
||||
auto r = getPyObj().attr("wrap_int")(num);
|
||||
return c10::make_intrusive<PythonSymNodeImpl>(r);
|
||||
}
|
||||
|
||||
c10::SymNode wrap_float(double num) override {
|
||||
py::gil_scoped_acquire acquire;
|
||||
auto r = getPyObj().attr("wrap_float")(num);
|
||||
return c10::make_intrusive<PythonSymNodeImpl>(r);
|
||||
}
|
||||
|
||||
bool bool_() override {
|
||||
py::gil_scoped_acquire acquire;
|
||||
return getPyObj().attr("bool_")().is(py::handle(Py_True));
|
||||
}
|
||||
|
||||
bool is_int() override {
|
||||
py::gil_scoped_acquire acquire;
|
||||
return getPyObj().attr("is_int")().is(py::handle(Py_True));
|
||||
}
|
||||
|
||||
bool is_float() override {
|
||||
py::gil_scoped_acquire acquire;
|
||||
return getPyObj().attr("is_float")().is(py::handle(Py_True));
|
||||
}
|
||||
|
||||
int64_t guard_int(const char* file, int64_t line) override {
|
||||
py::gil_scoped_acquire acquire;
|
||||
return getPyObj().attr("guard_int")(file, line).cast<int64_t>();
|
||||
}
|
||||
|
||||
double guard_float(const char* file, int64_t line) override {
|
||||
py::gil_scoped_acquire acquire;
|
||||
return getPyObj().attr("guard_float")(file, line).cast<double>();
|
||||
}
|
||||
|
||||
int64_t int_() override {
|
||||
py::gil_scoped_acquire acquire;
|
||||
return getPyObj().attr("int_")().cast<int64_t>();
|
||||
}
|
||||
|
||||
std::string str() override {
|
||||
py::gil_scoped_acquire acquire;
|
||||
return getPyObj().attr("str")().cast<std::string>();
|
||||
}
|
||||
|
||||
c10::SymNode dispatch_common_(const char* fname, const c10::SymNode& other) {
|
||||
auto pother = dynamic_cast<PythonSymNodeImpl*>(other.get());
|
||||
TORCH_CHECK(pother);
|
||||
py::gil_scoped_acquire acquire;
|
||||
auto r = getPyObj().attr(fname)(pother->getPyObj());
|
||||
return c10::make_intrusive<PythonSymNodeImpl>(r);
|
||||
}
|
||||
|
||||
c10::SymNode dispatch_common_(const char* fname) {
|
||||
py::gil_scoped_acquire acquire;
|
||||
auto r = getPyObj().attr(fname)();
|
||||
return c10::make_intrusive<PythonSymNodeImpl>(r);
|
||||
}
|
||||
|
||||
c10::SymNode add(const c10::SymNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
c10::SymNode sub(const c10::SymNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
c10::SymNode mul(const c10::SymNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
c10::SymNode truediv(const c10::SymNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
c10::SymNode pow(const c10::SymNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
c10::SymNode floordiv(const c10::SymNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
c10::SymNode mod(const c10::SymNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
c10::SymNode eq(const c10::SymNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
c10::SymNode gt(const c10::SymNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
c10::SymNode lt(const c10::SymNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
c10::SymNode le(const c10::SymNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
c10::SymNode ge(const c10::SymNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
c10::SymNode min(const c10::SymNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
c10::SymNode max(const c10::SymNode& other) override {
|
||||
return dispatch_common_(__FUNCTION__, other);
|
||||
}
|
||||
|
||||
c10::SymNode ceil() override {
|
||||
return dispatch_common_(__FUNCTION__);
|
||||
}
|
||||
|
||||
c10::SymNode floor() override {
|
||||
return dispatch_common_(__FUNCTION__);
|
||||
}
|
||||
|
||||
c10::SymNode neg() override {
|
||||
return dispatch_common_(__FUNCTION__);
|
||||
}
|
||||
|
||||
c10::SymNode clone() override {
|
||||
return dispatch_common_(__FUNCTION__);
|
||||
}
|
||||
|
||||
c10::SymNode sym_int() override {
|
||||
return dispatch_common_(__FUNCTION__);
|
||||
}
|
||||
|
||||
c10::SymNode sym_float() override {
|
||||
return dispatch_common_(__FUNCTION__);
|
||||
}
|
||||
|
||||
py::handle getPyObj() {
|
||||
return py::handle(pyobj_.get()->ptr(getPyInterpreter()));
|
||||
}
|
||||
std::shared_ptr<c10::SafePyObject> pyobj_ = nullptr;
|
||||
};
|
||||
|
||||
} // namespace impl
|
||||
} // namespace torch
|
@ -21,8 +21,9 @@ import operator
|
||||
|
||||
from torch.utils._python_dispatch import TorchDispatchMode, _pop_mode_temporarily, _get_current_dispatch_mode
|
||||
from torch._subclasses import FakeTensor
|
||||
from .symbolic_shapes import ShapeEnv, SymDispatchMode, PySymInt, PySymFloat
|
||||
from .symbolic_shapes import ShapeEnv, SymDispatchMode, SymNode
|
||||
from torch.fx import Proxy
|
||||
from torch import SymInt, SymFloat
|
||||
|
||||
__all__ = ["PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter", "get_proxy", "has_proxy", "py_sym_types"]
|
||||
aten = torch.ops.aten
|
||||
@ -55,27 +56,27 @@ def decompose(decomposition_table):
|
||||
proxy_slot = object()
|
||||
no_default = object()
|
||||
|
||||
py_sym_types = (
|
||||
PySymInt,
|
||||
PySymFloat,
|
||||
)
|
||||
py_sym_types = (SymInt, SymFloat)
|
||||
def is_sym_node(node):
|
||||
assert hasattr(node, 'meta'), "All nodes traced with proxy_tensor should have meta"
|
||||
return "val" in node.meta and isinstance(node.meta['val'], py_sym_types)
|
||||
|
||||
def set_proxy_slot(obj, tracer, proxy):
|
||||
d = obj.__dict__.setdefault(proxy_slot, weakref.WeakKeyDictionary())
|
||||
assert isinstance(obj, (torch.Tensor, SymNode)), type(obj)
|
||||
d = obj.__dict__.setdefault(proxy_slot, weakref.WeakKeyDictionary()) # type: ignore[call-overload]
|
||||
assert isinstance(d, weakref.WeakKeyDictionary)
|
||||
d[tracer] = proxy
|
||||
|
||||
def has_proxy_slot(obj, tracer):
|
||||
assert isinstance(obj, (torch.Tensor, SymNode)), type(obj)
|
||||
return get_proxy_slot(obj, tracer, False, lambda _: True)
|
||||
|
||||
# the default argument is what to return if the slot is not set.
|
||||
# the transform argument is handy if you need to extract a subfield from
|
||||
# the successfully looked up result (but NOT the default.)
|
||||
def get_proxy_slot(obj, tracer, default=no_default, transform=lambda x: x):
|
||||
d = obj.__dict__.get(proxy_slot)
|
||||
assert isinstance(obj, (torch.Tensor, SymNode)), type(obj)
|
||||
d = obj.__dict__.get(proxy_slot) # type: ignore[call-overload]
|
||||
if not d:
|
||||
if default is no_default:
|
||||
raise KeyError(f"{obj} is not tracked with proxy for {tracer}")
|
||||
@ -130,10 +131,8 @@ def track_tensor(tensor, proxy, *, constant, tracer):
|
||||
def try_set_proxy_slot(outer_s, proxy_callable, *args):
|
||||
assert callable(proxy_callable)
|
||||
if isinstance(outer_s, SymInt):
|
||||
inner_s = outer_s.get_pyobj()
|
||||
assert isinstance(inner_s, py_sym_types)
|
||||
|
||||
set_proxy_slot(inner_s, tracer, thunkify(proxy_callable, inner_s, *args))
|
||||
inner_s = outer_s.node
|
||||
set_proxy_slot(inner_s, tracer, thunkify(proxy_callable, outer_s, *args))
|
||||
|
||||
# The basic idea is that we need to associate each tensor/SymInt
|
||||
# with a Proxy. How do we setup this association? We just store
|
||||
@ -198,7 +197,7 @@ class _ProxyTensor:
|
||||
|
||||
def fetch_sym_proxy(tracer):
|
||||
def inner(e):
|
||||
n = e.get_pyobj()
|
||||
n = e.node
|
||||
if n.constant is not None:
|
||||
return n.constant
|
||||
else:
|
||||
@ -400,8 +399,8 @@ class PythonKeyTracer(Tracer):
|
||||
|
||||
return self.create_node('get_attr', qualname, (), {})
|
||||
elif isinstance(a, (SymInt, SymFloat)):
|
||||
assert a.get_pyobj().constant is not None
|
||||
return a.get_pyobj().constant
|
||||
assert a.node.constant is not None
|
||||
return a.node.constant
|
||||
return super().create_arg(a)
|
||||
|
||||
|
||||
@ -432,7 +431,7 @@ def wrap_key(f, tensors, tracer):
|
||||
)
|
||||
out = pytree.tree_map_only(
|
||||
(SymInt, SymFloat),
|
||||
lambda t: get_proxy_slot(t.get_pyobj(), tracer)(),
|
||||
lambda t: get_proxy_slot(t.node, tracer)(),
|
||||
out
|
||||
)
|
||||
return out
|
||||
@ -479,10 +478,6 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
|
||||
return out
|
||||
|
||||
|
||||
SymInt = torch.SymIntNode
|
||||
SymFloat = torch.SymFloatNode
|
||||
|
||||
|
||||
class ProxySymDispatchMode(SymDispatchMode):
|
||||
def __init__(self, tracer):
|
||||
super().__init__()
|
||||
@ -501,10 +496,9 @@ class ProxySymDispatchMode(SymDispatchMode):
|
||||
finally:
|
||||
self.enable_tracing = old
|
||||
|
||||
def _compute_proxy(self, func, args, out):
|
||||
def _compute_proxy(self, func, args, out: Union[SymInt, SymFloat]):
|
||||
n_args = tuple(
|
||||
get_proxy_slot(a, self.tracer)().node if a.constant is None else a.constant
|
||||
if isinstance(a, py_sym_types) else a
|
||||
get_proxy_slot(a.node, self.tracer)().node if isinstance(a, py_sym_types) else a
|
||||
for a in args
|
||||
)
|
||||
|
||||
@ -520,10 +514,11 @@ class ProxySymDispatchMode(SymDispatchMode):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# Peephole optimize multiply by one
|
||||
# NB: be careful not to trigger guards here!
|
||||
if func == operator.mul:
|
||||
if isinstance(args[1], (PySymInt, PySymFloat)) and args[1].constant == 1:
|
||||
if isinstance(args[1], int) and args[1] == 1:
|
||||
return args[0]
|
||||
elif isinstance(args[0], (PySymInt, PySymFloat)) and args[0].constant == 1:
|
||||
elif isinstance(args[0], int) and args[0] == 1:
|
||||
return args[1]
|
||||
|
||||
# For speed, we assume there are no nested data structures
|
||||
@ -535,7 +530,7 @@ class ProxySymDispatchMode(SymDispatchMode):
|
||||
|
||||
# Delays tracing out the proxies on this op until we actually need it
|
||||
p_out_thunk = thunkify(self._compute_proxy, func=func, args=args, out=out)
|
||||
set_proxy_slot(out, self.tracer, p_out_thunk)
|
||||
set_proxy_slot(out.node, self.tracer, p_out_thunk)
|
||||
return out
|
||||
|
||||
|
||||
|
@ -10,6 +10,7 @@ import traceback
|
||||
import collections
|
||||
import textwrap
|
||||
from torch._subclasses.meta_utils import MetaConverter
|
||||
from torch import SymInt, SymFloat
|
||||
|
||||
try:
|
||||
import sympy # type: ignore[import]
|
||||
@ -21,8 +22,8 @@ except ImportError:
|
||||
aten = torch.ops.aten # type: ignore[has-type]
|
||||
|
||||
__all__ = [
|
||||
"has_symbolic_sizes_strides", "create_contiguous", "PySymInt", "ShapeEnv",
|
||||
"SymDispatchMode", "PySymFloat", "sym_float", "FloorDiv"
|
||||
"has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv",
|
||||
"SymDispatchMode", "sym_float", "FloorDiv", "guard_int", "wrap_node"
|
||||
]
|
||||
|
||||
SYM_FUNCTION_MODE = None
|
||||
@ -88,32 +89,38 @@ def _handle_sym_dispatch(func, args, kwargs):
|
||||
finally:
|
||||
SYM_FUNCTION_MODE = mode
|
||||
|
||||
def guard_int(a):
|
||||
if isinstance(a, SymInt):
|
||||
return a.node.guard_int("", 0) # NB: uses Python backtrace
|
||||
assert isinstance(a, int)
|
||||
return a
|
||||
|
||||
def sym_float(a):
|
||||
if hasattr(a, '__sym_float__'):
|
||||
return a.__sym_float__()
|
||||
elif isinstance(a, torch._C.SymFloatNode):
|
||||
if isinstance(a, SymFloat):
|
||||
return a
|
||||
elif hasattr(a, '__sym_float__'):
|
||||
return a.__sym_float__()
|
||||
return float(a)
|
||||
|
||||
def sym_int(a):
|
||||
if hasattr(a, '__sym_int__'):
|
||||
return a.__sym_int__()
|
||||
elif isinstance(a, torch._C.SymIntNode):
|
||||
if isinstance(a, SymInt):
|
||||
return a
|
||||
elif hasattr(a, '__sym_int__'):
|
||||
return a.__sym_int__()
|
||||
return int(a)
|
||||
|
||||
# TODO: An incomplete list
|
||||
# 1. Set variables to be equal when we do equality
|
||||
# 2. Specialize on 0/1 when we do subtraction
|
||||
class PySymInt(object):
|
||||
class SymNode:
|
||||
"""
|
||||
PySymInt objects are the primary "symbolic shape" objects that flow through
|
||||
our program. They're what sit under FakeTensor, and contains our primary
|
||||
implementation of symbolic shapes.
|
||||
This is a type erased SymInt/SymFloat which we use to do actual operations.
|
||||
End users don't touch this. Magic methods are NOT defined on this object.
|
||||
"""
|
||||
def __init__(self, expr, shape_env, constant=None):
|
||||
def __init__(self, expr, shape_env, pytype, constant=None):
|
||||
self._expr = expr
|
||||
self.shape_env = shape_env
|
||||
self.pytype = pytype
|
||||
self.constant = constant
|
||||
|
||||
@property
|
||||
@ -121,23 +128,49 @@ class PySymInt(object):
|
||||
self._update_expr()
|
||||
return self._expr
|
||||
|
||||
def wrap(self, num):
|
||||
return PySymInt(sympy.Integer(num), self.shape_env, constant=num)
|
||||
|
||||
def clone(self):
|
||||
return PySymInt(self.expr, self.shape_env, constant=self.constant)
|
||||
|
||||
def _update_expr(self):
|
||||
self._expr = self.shape_env.replace(self._expr)
|
||||
|
||||
def __str__(self):
|
||||
def to_node(self, num):
|
||||
if isinstance(num, (SymInt, SymFloat)):
|
||||
return num.node
|
||||
elif isinstance(num, int):
|
||||
return self.wrap_int(num)
|
||||
elif isinstance(num, float):
|
||||
return self.wrap_float(num)
|
||||
else:
|
||||
# NotImplementedError is important so that Python tries the
|
||||
# other magic method
|
||||
raise NotImplementedError(type(num))
|
||||
|
||||
def is_int(self):
|
||||
return self.pytype is int
|
||||
|
||||
def is_float(self):
|
||||
return self.pytype is float
|
||||
|
||||
def wrap_int(self, num):
|
||||
assert isinstance(num, int)
|
||||
return SymNode(sympy.Integer(num), self.shape_env, int, constant=num)
|
||||
|
||||
def wrap_float(self, num):
|
||||
assert isinstance(num, float)
|
||||
return SymNode(sympy.Integer(num), self.shape_env, float, constant=num)
|
||||
|
||||
def clone(self):
|
||||
return SymNode(self.expr, self.shape_env, self.pytype, constant=self.constant)
|
||||
|
||||
def str(self):
|
||||
return f"{self.expr}"
|
||||
|
||||
def __str__(self):
|
||||
return self.str()
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.expr}"
|
||||
return self.str()
|
||||
|
||||
# Today we error on calling int on a symbolic shape, as this is a very accessible footgun.
|
||||
def __int__(self):
|
||||
def int_(self):
|
||||
raise RuntimeError("Trying to extract a concrete int out of a symbolic int")
|
||||
|
||||
# You can manually trigger a guard with this function
|
||||
@ -146,28 +179,35 @@ class PySymInt(object):
|
||||
# guard occurred
|
||||
return int(self.shape_env.evaluate_expr(self.expr))
|
||||
|
||||
def __sym_float__(self):
|
||||
def guard_float(self, file, line):
|
||||
# TODO: use the file/line for some useful diagnostic on why a
|
||||
# guard occurred
|
||||
return float(self.shape_env.evaluate_expr(self.expr))
|
||||
|
||||
def sym_float(self):
|
||||
if SYM_FUNCTION_MODE:
|
||||
return _handle_sym_dispatch(sym_float, (self,), {})
|
||||
r = _handle_sym_dispatch(sym_float, (wrap_node(self),), {})
|
||||
assert isinstance(r, (SymInt, SymFloat)), type(r)
|
||||
return r.node
|
||||
# TODO: consider constant prop here
|
||||
# TODO: wrapping the expr with sympy.Float doesn't seem to work, why
|
||||
# not?
|
||||
return PySymFloat(self.expr, self.shape_env)
|
||||
return SymNode(self.expr, self.shape_env, float)
|
||||
|
||||
def __bool__(self):
|
||||
def sym_int(self):
|
||||
raise NotImplementedError("sym_int NYI")
|
||||
"""
|
||||
if SYM_FUNCTION_MODE:
|
||||
return _handle_sym_dispatch(sym_int, (self,), {})
|
||||
# TODO: consider constant prop here
|
||||
# XXX: need to cast float to int in sympy; math.floor is wrong
|
||||
# because negatives round to zero
|
||||
return SymNode(self.expr, self.shape_env, int)
|
||||
"""
|
||||
|
||||
def bool_(self):
|
||||
return bool(self.shape_env.evaluate_expr(self.shape_env.replace(self.expr)))
|
||||
|
||||
class PySymFloat:
|
||||
def __init__(self, expr, shape_env, constant=None):
|
||||
self.expr = expr
|
||||
self.shape_env = shape_env
|
||||
self.constant = constant
|
||||
|
||||
def wrap(self, num):
|
||||
return PySymFloat(sympy.Float(num), self.shape_env, constant=num)
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.expr}"
|
||||
|
||||
if HAS_SYMPY:
|
||||
class FloorDiv(sympy.Function):
|
||||
@ -238,32 +278,45 @@ unary_magic_methods = {
|
||||
|
||||
float_magic_methods = {"add", "sub", "mul", "truediv", "ceil", "floor", "eq", "gt", "lt", "le", "ge", "pow"}
|
||||
|
||||
def _make_magic(method, func, py_type):
|
||||
def wrap_node(x):
|
||||
if not isinstance(x, SymNode):
|
||||
return x
|
||||
if x.constant is not None:
|
||||
return x.constant
|
||||
if x.pytype is int:
|
||||
return SymInt(x)
|
||||
elif x.pytype is float:
|
||||
return SymFloat(x)
|
||||
else:
|
||||
raise AssertionError(f"unrecognized return type {x.pytype}")
|
||||
|
||||
def _make_node_magic(method, func):
|
||||
func = lru_cache(256)(func)
|
||||
|
||||
def magic_impl(self, other):
|
||||
def binary_magic_impl(self, other):
|
||||
if method in ["min", "max"]:
|
||||
op = getattr(builtins, method)
|
||||
else:
|
||||
op = getattr(operator, method)
|
||||
if SYM_FUNCTION_MODE:
|
||||
return _handle_sym_dispatch(op, (self, other), {})
|
||||
if isinstance(other, py_type):
|
||||
other_expr = other.expr
|
||||
else:
|
||||
assert isinstance(other, sympy.Expr)
|
||||
other_expr = other
|
||||
r = _handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {})
|
||||
assert isinstance(r, (SymInt, SymFloat)), type(r)
|
||||
return r.node
|
||||
assert isinstance(other, SymNode)
|
||||
other_expr = other.expr
|
||||
# TODO: consider constant prop here
|
||||
expr = self.shape_env.replace(self.expr)
|
||||
other_expr = self.shape_env.replace(other_expr)
|
||||
out = func(expr, other_expr)
|
||||
out = sympy.expand(out)
|
||||
if method in ["truediv"]:
|
||||
return PySymFloat(out, self.shape_env)
|
||||
pytype = float
|
||||
else:
|
||||
# TODO: relational operators actually technically return a
|
||||
# PySymBool, this is a type error
|
||||
return py_type(out, self.shape_env)
|
||||
pytype = self.pytype
|
||||
|
||||
# TODO: relational operators actually technically return a
|
||||
# PySymBool, this is a type error
|
||||
return SymNode(out, self.shape_env, pytype)
|
||||
|
||||
def unary_magic_impl(self):
|
||||
if SYM_FUNCTION_MODE:
|
||||
@ -271,33 +324,55 @@ def _make_magic(method, func, py_type):
|
||||
op = getattr(math, method)
|
||||
else:
|
||||
op = getattr(operator, method)
|
||||
return _handle_sym_dispatch(op, (self,), {})
|
||||
r = _handle_sym_dispatch(op, (wrap_node(self),), {})
|
||||
assert isinstance(r, (SymInt, SymFloat)), type(r)
|
||||
return r.node
|
||||
# TODO: consider constant prop here
|
||||
expr = self.shape_env.replace(self.expr)
|
||||
out = func(expr)
|
||||
out = sympy.expand(out)
|
||||
if method in ["ceil", "floor"]:
|
||||
return PySymInt(out, self.shape_env)
|
||||
pytype = int
|
||||
else:
|
||||
return py_type(out, self.shape_env)
|
||||
pytype = self.pytype
|
||||
|
||||
return SymNode(out, self.shape_env, pytype)
|
||||
|
||||
# this should be wrapped transparently into torch.SymIntNode
|
||||
if method in unary_magic_methods:
|
||||
setattr(py_type, method, unary_magic_impl)
|
||||
setattr(py_type, f"__{method}__", unary_magic_impl)
|
||||
setattr(SymNode, method, unary_magic_impl)
|
||||
else:
|
||||
setattr(py_type, method, magic_impl)
|
||||
setattr(py_type, f"__{method}__", magic_impl)
|
||||
if method in reflectable_magic_methods:
|
||||
setattr(py_type, f"__r{method}__", magic_impl)
|
||||
setattr(SymNode, method, binary_magic_impl)
|
||||
|
||||
for method, func in magic_methods.items():
|
||||
_make_magic(method, func, PySymInt)
|
||||
_make_node_magic(method, func)
|
||||
|
||||
def _make_user_magic(method, user_type):
|
||||
# User magic takes care of wrapping the other operand into a node,
|
||||
# so that our internal logic can assume everything is nodes
|
||||
|
||||
def unary_magic_impl(self):
|
||||
return wrap_node(getattr(self.node, method)())
|
||||
|
||||
def binary_magic_impl(self, other):
|
||||
return wrap_node(getattr(self.node, method)(self.node.to_node(other)))
|
||||
|
||||
def rbinary_magic_impl(self, other):
|
||||
return wrap_node(getattr(self.node.to_node(other), method)(self.node))
|
||||
|
||||
if method in unary_magic_methods:
|
||||
setattr(user_type, f"__{method}__", unary_magic_impl)
|
||||
else:
|
||||
setattr(user_type, f"__{method}__", binary_magic_impl)
|
||||
if method in reflectable_magic_methods:
|
||||
setattr(user_type, f"__r{method}__", rbinary_magic_impl)
|
||||
|
||||
for method, func in magic_methods.items():
|
||||
_make_user_magic(method, SymInt)
|
||||
|
||||
for method, func in magic_methods.items():
|
||||
if method not in float_magic_methods:
|
||||
continue
|
||||
_make_magic(method, func, PySymFloat)
|
||||
_make_user_magic(method, SymFloat)
|
||||
|
||||
del method
|
||||
del func
|
||||
@ -390,9 +465,7 @@ class ShapeEnv(object):
|
||||
return [self.create_symintnode(i) for i in size], [self.create_symintnode(i) for i in stride] # type: ignore[arg-type]
|
||||
|
||||
def create_symintnode(self, expr: "sympy.Expr"):
|
||||
py_sym_int = PySymInt(expr, self)
|
||||
cpp_sym_int = torch.SymIntNode.new_symint(py_sym_int) # type: ignore[attr-defined]
|
||||
return cpp_sym_int
|
||||
return SymInt(SymNode(expr, self, int))
|
||||
|
||||
def create_symbol(self, val: int) -> "sympy.Expr":
|
||||
if not HAS_SYMPY:
|
||||
|
@ -498,7 +498,7 @@ class CodeGen(object):
|
||||
if isinstance(meta_val, FakeTensor):
|
||||
maybe_type_annotation = f': {dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}'
|
||||
elif isinstance(meta_val, py_sym_types):
|
||||
maybe_type_annotation = f': Sym({meta_val.expr})'
|
||||
maybe_type_annotation = f': Sym({meta_val})'
|
||||
elif isinstance(meta_val, TensorMetadata):
|
||||
maybe_type_annotation = f': {dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}'
|
||||
|
||||
|
Reference in New Issue
Block a user