Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Unify SymIntNode and SymFloatNode into SymNode (#87817)
This refactor was prompted by challenges handling mixed int/float operations in C++. A previous version of this patch added overloads for each permutation of int/float and was unwieldy: https://github.com/pytorch/pytorch/pull/87722/ This PR takes a different approach.

The general outline of the patch is to combine the C++ types SymIntNode and SymFloatNode into a single type, SymNode. This type is erased; we no longer know statically in C++ whether we have an int or a float, and have to test with the is_int()/is_float() virtual methods. This has a number of knock-on effects.

- We no longer have C++ classes to bind to Python. Instead, we take an entirely new approach to our Python API, where we have SymInt/SymFloat classes defined entirely in Python, which hold a SymNode (which corresponds to the C++ SymNode). However, SymNode is not pybind11-bound; instead, it lives as-is in Python and is wrapped into a C++ SymNode using PythonSymNode when it goes into C++. This implies a userland rename. In principle, it is also possible for the canonical implementation of SymNode to be written in C++ and then bound to Python with pybind11 (we have this code, although it is commented out); however, I did not implement this, as we currently have no C++ implementations of SymNode. Because we do return SymInt/SymFloat from C++ bindings, the C++ binding code needs to know how to find these classes. Currently, this is done just by manually importing torch and getting the attributes.
- Because SymInt/SymFloat are easy Python wrappers, __sym_dispatch__ now takes SymInt/SymFloat, rather than SymNode, bringing it in line with how __torch_dispatch__ works.

Some miscellaneous improvements:

- SymInt now has a constructor that takes SymNode. Note that this constructor is ambiguous if you pass in a subclass of SymNode, so an explicit downcast is necessary. This means toSymFloat/toSymInt are no more. This is a mild optimization, as it means rvalue references work automatically.
- We uniformly use the caster for c10::SymInt/SymFloat, rather than going the long way via SymIntNode/SymFloatNode.
- Removed some unnecessary toSymInt/toSymFloat calls in the normalize_* functions; pretty sure this doesn't do anything.
- guard_int is now a free function, since to guard on an int you cannot assume the method exists (a plain int has no such method); a free function can handle both int and SymInt inputs.
- We clean up the magic method definition code for SymInt/SymFloat/SymNode. ONLY the user classes (SymInt/SymFloat) get magic methods; SymNode gets plain methods. This helps avoid confusion between the two types.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

cc @jansel @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87817
Approved by: https://github.com/albanD, https://github.com/anjali411
Committed by: PyTorch MergeBot
Parent: 2205f56f46
Commit: 1ff52225f1
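For orientation before the diff: the core idea is that a single type-erased node replaces the separate SymIntNode/SymFloatNode C++ types, so one binary operation covers every int/float operand permutation. Below is a minimal standalone sketch of that pattern; Node, IntNode, FloatNode, and add are illustrative names for this sketch only, not the torch API (the real interface is c10::SymNodeImpl, shown in the diff further down).

// Standalone sketch of the type-erased node pattern this PR adopts.
// Illustrative only; not PyTorch code.
#include <iostream>
#include <memory>
#include <string>

struct Node {
  virtual ~Node() = default;
  virtual bool is_int() const = 0;   // runtime query replaces static typing
  virtual bool is_float() const = 0;
  virtual double as_double() const = 0;
  virtual std::string str() const = 0;
};

struct IntNode : Node {
  long v;
  explicit IntNode(long v) : v(v) {}
  bool is_int() const override { return true; }
  bool is_float() const override { return false; }
  double as_double() const override { return static_cast<double>(v); }
  std::string str() const override { return std::to_string(v); }
};

struct FloatNode : Node {
  double v;
  explicit FloatNode(double v) : v(v) {}
  bool is_int() const override { return false; }
  bool is_float() const override { return true; }
  double as_double() const override { return v; }
  std::string str() const override { return std::to_string(v); }
};

// One add() handles int+int, int+float, float+int, and float+float; with
// two statically distinct node types this took an overload per permutation.
std::shared_ptr<Node> add(const std::shared_ptr<Node>& a,
                          const std::shared_ptr<Node>& b) {
  if (a->is_int() && b->is_int()) {
    return std::make_shared<IntNode>(static_cast<long>(a->as_double()) +
                                     static_cast<long>(b->as_double()));
  }
  return std::make_shared<FloatNode>(a->as_double() + b->as_double());
}

int main() {
  auto x = std::make_shared<IntNode>(3);
  auto y = std::make_shared<FloatNode>(0.5);
  std::cout << add(x, y)->str() << "\n";  // mixed operands promote to float
}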
.github/ci_commit_pins/xla.txt (vendored)
@@ -1 +1 @@
-095ee628212f0235ad0d6908bdd514123639fc86
+1e9b8bdc75114ac6c16305c970be37a1cd2fdb1c
@@ -439,7 +439,7 @@ command = [
 """--error-description=\
 This line has an isinstance call that directly refers to \
 int or float. This is error-prone because you may also \
-have wanted to allow SymIntNode or SymFloatNode in your test. \
+have wanted to allow SymInt or SymFloat in your test. \
 To suppress this lint, use an appropriate type alias defined \
 in torch._prims_common; use IntLike/FloatLike when you would accept \
 both regular and symbolic numbers, Dim for ints representing \
@@ -95,7 +95,7 @@ c10::SymInt get_nbytes(const Tensor& value) {
 if (value.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {
 // Today, the two implementations of SymInt are in Python (proxy tensor),
 // and lazy tensor (LTC/XLA).
-// LTC hasn't implemented SymInt support yet though (torch::lazy::SymIntNodeImpl).
+// LTC hasn't implemented SymInt support yet though
 // Once it does, we should remove this check.
 if (value.key_set().has(c10::DispatchKey::Python)) {
 return value.storage().sym_nbytes();
@@ -562,7 +562,7 @@ public:
 IValue(c10::SymInt i) {
 if (i.is_symbolic()) {
 tag = Tag::SymInt;
-payload.u.as_intrusive_ptr = i.toSymIntNodeImpl().release();
+payload.u.as_intrusive_ptr = i.toSymNodeImpl().release();
 } else {
 tag = Tag::Int;
 payload.u.as_int = i.as_int_unchecked();
@@ -578,7 +578,7 @@ public:
 IValue(c10::SymFloat i) {
 if (i.is_symbolic()) {
 tag = Tag::SymFloat;
-payload.u.as_intrusive_ptr = i.toSymFloatNodeImpl().release();
+payload.u.as_intrusive_ptr = i.toSymNodeImpl().release();
 } else {
 tag = Tag::Double;
 payload.u.as_double = i.as_float_unchecked();
@@ -812,10 +812,10 @@ public:
 // for both SymFloat and double
 if (s.isSymInt()) {
 tag = Tag::SymInt;
-payload.u.as_intrusive_ptr = s.toSymInt().toSymIntNodeImpl().release();
+payload.u.as_intrusive_ptr = s.toSymInt().toSymNodeImpl().release();
 } else if (s.isSymFloat()) {
 tag = Tag::SymFloat;
-payload.u.as_intrusive_ptr = s.toSymFloat().toSymFloatNodeImpl().release();
+payload.u.as_intrusive_ptr = s.toSymFloat().toSymNodeImpl().release();
 } else if (s.isFloatingPoint()) {
 tag = Tag::Double;
 payload.u.as_double = s.toDouble();
@@ -219,7 +219,7 @@ inline at::Generator IValue::toGenerator() const& {
 inline c10::SymInt IValue::toSymInt() const {
 AT_ASSERT(isSymInt() || isInt(), "Expected SymInt or int but got ", tagKind());
 if (isSymInt()) {
-return c10::SymInt::toSymInt(toIntrusivePtr<c10::SymIntNodeImpl>());
+return c10::SymInt(toIntrusivePtr<c10::SymNodeImpl>());
 } else {
 return c10::SymInt(payload.u.as_int);
 }
@@ -228,7 +228,7 @@ inline c10::SymInt IValue::toSymInt() const {
 inline c10::SymFloat IValue::toSymFloat() const {
 AT_ASSERT(isSymFloat() || isDouble(), "Expected SymFloat or double but got ", tagKind());
 if (isSymFloat()) {
-return c10::SymFloat::toSymFloat(toIntrusivePtr<c10::SymFloatNodeImpl>());
+return c10::SymFloat(toIntrusivePtr<c10::SymNodeImpl>());
 } else {
 return c10::SymFloat(payload.u.as_double);
 }
@@ -1310,7 +1310,6 @@ struct TORCH_API SymIntType : public Type {
 return "SymInt";
 }
 std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
-// TODO: will become a Union[SymIntNodeImpl|int] in the near future
 return "int";
 }
 static const TypeKind Kind = TypeKind::SymIntType;
@@ -194,34 +194,3 @@ TEST(TestScalar, TestFormatting) {
 ASSERT_EQ("(2,3.1)", format(Scalar(c10::complex<float>(2.0, 3.1))));
 ASSERT_EQ("4", format(Scalar(Scalar(4).toSymInt())));
 }
-
-TEST(TestSymInt, Basic) {
-Scalar foo;
-auto a_impl = c10::make_intrusive<c10::SymIntNodeImpl>();
-foo = Scalar(a_impl->toSymInt());
-ASSERT_EQ(a_impl.use_count(), 2);
-Scalar bar{foo};
-ASSERT_EQ(a_impl.use_count(), 3);
-auto baz = bar;
-ASSERT_EQ(a_impl.use_count(), 4);
-auto foo2 = std::move(bar);
-ASSERT_EQ(a_impl.use_count(), 4);
-ASSERT_TRUE(foo2.isSymInt());
-// NOLINTNEXTLINE(bugprone-use-after-move,clang-analyzer-cplusplus.Move)
-ASSERT_TRUE(bar.isIntegral(false));
-foo2 = SymInt(4);
-ASSERT_FALSE(foo2.isSymInt());
-ASSERT_EQ(foo2.toSymInt().expect_int(), 4);
-// NOLINTNEXTLINE(clang-diagnostic-self-assign-overloaded)
-foo2 = foo2;
-ASSERT_FALSE(foo2.isSymInt());
-ASSERT_EQ(foo2.toSymInt().expect_int(), 4);
-
-ASSERT_EQ(a_impl.use_count(), 3);
-
-ASSERT_THROW(foo.to<double>(), c10::Error);
-
-Scalar int_s = 3;
-TORCH_CHECK(int_s.toSymInt().expect_int(), 3);
-
-}
@@ -958,6 +958,7 @@ libtorch_python_core_sources = [
 "torch/csrc/utils/object_ptr.cpp",
 "torch/csrc/utils/python_arg_parser.cpp",
 "torch/csrc/utils/python_dispatch.cpp",
+"torch/csrc/utils/python_symnode.cpp",
 "torch/csrc/utils/structseq.cpp",
 "torch/csrc/utils/tensor_apply.cpp",
 "torch/csrc/utils/tensor_dtypes.cpp",
@@ -92,8 +92,8 @@ class C10_API Scalar {
 
 SymInt toSymInt() const {
 if (Tag::HAS_si == tag) {
-return c10::SymInt::toSymInt(intrusive_ptr<SymIntNodeImpl>::reclaim_copy(
-static_cast<SymIntNodeImpl*>(v.p)));
+return c10::SymInt(intrusive_ptr<SymNodeImpl>::reclaim_copy(
+static_cast<SymNodeImpl*>(v.p)));
 } else {
 return toLong();
 }
@@ -101,9 +101,8 @@ class C10_API Scalar {
 
 SymFloat toSymFloat() const {
 if (Tag::HAS_sd == tag) {
-return c10::SymFloat::toSymFloat(
-intrusive_ptr<SymFloatNodeImpl>::reclaim_copy(
-static_cast<SymFloatNodeImpl*>(v.p)));
+return c10::SymFloat(intrusive_ptr<SymNodeImpl>::reclaim_copy(
+static_cast<SymNodeImpl*>(v.p)));
 } else {
 return toDouble();
 }
@@ -1,32 +1,27 @@
 #include <c10/core/SymFloat.h>
-#include <c10/core/SymFloatNodeImpl.h>
+#include <c10/core/SymNodeImpl.h>
 #include <array>
 
 namespace c10 {
 
-SymFloatNode SymFloat::toSymFloatNodeImpl() const {
+SymNode SymFloat::toSymNodeImpl() const {
 TORCH_CHECK(is_symbolic());
-return SymFloatNode::reclaim_copy(toSymFloatNodeImplUnowned());
+return SymNode::reclaim_copy(toSymNodeImplUnowned());
 }
 
-static std::array<SymFloatNode, 2> normalize_symfloats(
-SymFloat a_,
-SymFloat b_) {
-SymFloatNode a, b;
+static std::array<SymNode, 2> normalize_symfloats(SymFloat a_, SymFloat b_) {
+SymNode a, b;
 if (a_.is_symbolic())
-a = a_.toSymFloatNodeImpl();
+a = a_.toSymNodeImpl();
 if (b_.is_symbolic())
-b = b_.toSymFloatNodeImpl();
+b = b_.toSymNodeImpl();
 
-SymFloatNodeImpl* common = a ? a.get() : b.get();
-// TODO: technically we need to check that the classes match
+SymNodeImpl* common = a ? a.get() : b.get();
 if (!a) {
-a = common->wrap(a_.as_float_unchecked());
-a_.toSymFloat(a); //
+a = common->wrap_float(a_.as_float_unchecked());
 }
 if (!b) {
-b = common->wrap(b_.as_float_unchecked());
-b_.toSymFloat(b);
+b = common->wrap_float(b_.as_float_unchecked());
 }
 return {a, b};
 }
@@ -36,7 +31,7 @@ SymFloat SymFloat::operator+(SymFloat sci) const {
 return SymFloat(data_ + sci.data_);
 }
 auto res = normalize_symfloats(*this, sci);
-return SymFloat::toSymFloat(res[0]->add(res[1]));
+return SymFloat(res[0]->add(res[1]));
 }
 
 SymFloat SymFloat::operator-(SymFloat sci) const {
@@ -44,7 +39,7 @@ SymFloat SymFloat::operator-(SymFloat sci) const {
 return SymFloat(data_ - sci.data_);
 }
 auto res = normalize_symfloats(*this, sci);
-return SymFloat::toSymFloat(res[0]->sub(res[1]));
+return SymFloat(res[0]->sub(res[1]));
 }
 
 SymFloat SymFloat::operator*(SymFloat sci) const {
@@ -52,7 +47,7 @@ SymFloat SymFloat::operator*(SymFloat sci) const {
 return SymFloat(data_ * sci.data_);
 }
 auto res = normalize_symfloats(*this, sci);
-return SymFloat::toSymFloat(res[0]->mul(res[1]));
+return SymFloat(res[0]->mul(res[1]));
 }
 
 SymFloat SymFloat::operator/(SymFloat sci) const {
@@ -60,16 +55,12 @@ SymFloat SymFloat::operator/(SymFloat sci) const {
 return SymFloat(data_ / sci.data_);
 }
 auto res = normalize_symfloats(*this, sci);
-return SymFloat::toSymFloat(res[0]->truediv(res[1]));
-}
-
-c10::SymFloat SymFloat::toSymFloat(SymFloatNode sin_sp) {
-return c10::SymFloat(std::move(sin_sp));
+return SymFloat(res[0]->truediv(res[1]));
 }
 
 std::ostream& operator<<(std::ostream& os, SymFloat s) {
 if (s.is_symbolic()) {
-os << s.toSymFloatNodeImpl()->str();
+os << s.toSymNodeImpl()->str();
 } else {
 os << s.as_float_unchecked();
 }
@@ -1,6 +1,6 @@
 #pragma once
 
-#include <c10/core/SymFloatNodeImpl.h>
+#include <c10/core/SymNodeImpl.h>
 #include <c10/macros/Macros.h>
 #include <c10/util/Exception.h>
 #include <c10/util/intrusive_ptr.h>
@@ -14,20 +14,21 @@ namespace c10 {
 class C10_API SymFloat {
 public:
 /*implicit*/ SymFloat(double d) : data_(d){};
-SymFloat(SymFloatNode ptr)
-: data_(std::numeric_limits<double>::quiet_NaN()), ptr_(std::move(ptr)){};
+SymFloat(SymNode ptr)
+: data_(std::numeric_limits<double>::quiet_NaN()), ptr_(std::move(ptr)) {
+TORCH_CHECK(ptr_->is_float());
+};
 SymFloat() : data_(0.0) {}
 
-SymFloatNodeImpl* toSymFloatNodeImplUnowned() const {
+SymNodeImpl* toSymNodeImplUnowned() const {
 return ptr_.get();
 }
 
-SymFloatNodeImpl* release() && {
+SymNodeImpl* release() && {
 return std::move(ptr_).release();
 }
 
-SymFloatNode toSymFloatNodeImpl() const;
-static c10::SymFloat toSymFloat(SymFloatNode sin);
+SymNode toSymNodeImpl() const;
 
 double expect_float() const {
 TORCH_CHECK(!is_symbolic());
@@ -53,7 +54,7 @@ class C10_API SymFloat {
 private:
 // TODO: optimize to union
 double data_;
-SymFloatNode ptr_;
+SymNode ptr_;
 };
 
 C10_API std::ostream& operator<<(std::ostream& os, SymFloat s);
@@ -1,20 +0,0 @@
-#include <c10/core/SymFloat.h>
-#include <c10/core/SymFloatNodeImpl.h>
-#include <c10/core/SymIntNodeImpl.h>
-
-namespace c10 {
-
-c10::SymFloat SymFloatNodeImpl::toSymFloat() {
-auto sit_sp = SymFloatNode::reclaim_copy(this);
-return SymFloat::toSymFloat(sit_sp);
-}
-
-c10::SymIntNode SymFloatNodeImpl::ceil() {
-TORCH_CHECK(false, "NYI");
-}
-
-c10::SymIntNode SymFloatNodeImpl::floor() {
-TORCH_CHECK(false, "NYI");
-}
-
-} // namespace c10
@@ -1,76 +0,0 @@
-#pragma once
-
-#include <c10/macros/Macros.h>
-#include <c10/util/Exception.h>
-#include <c10/util/intrusive_ptr.h>
-#include <memory>
-#include <mutex>
-#include <vector>
-
-namespace c10 {
-
-class SymIntNodeImpl;
-using SymIntNode = c10::intrusive_ptr<SymIntNodeImpl>;
-
-class SymFloat;
-class SymFloatNodeImpl;
-using SymFloatNode = c10::intrusive_ptr<SymFloatNodeImpl>;
-
-class C10_API SymFloatNodeImpl : public c10::intrusive_ptr_target {
-public:
-c10::SymFloat toSymFloat();
-virtual ~SymFloatNodeImpl(){};
-
-template <typename T>
-c10::intrusive_ptr<T> dyn_cast() const {
-return c10::intrusive_ptr<T>::reclaim_copy(dynamic_cast<T*>(this));
-}
-
-virtual SymFloatNode wrap(double num) {
-TORCH_CHECK(false, "NYI");
-};
-virtual SymFloatNode add(const SymFloatNode& other) {
-TORCH_CHECK(false, "NYI");
-}
-virtual SymFloatNode sub(const SymFloatNode& other) {
-TORCH_CHECK(false, "NYI");
-}
-virtual SymFloatNode mul(const SymFloatNode& other) {
-TORCH_CHECK(false, "NYI");
-}
-virtual SymFloatNode truediv(const SymFloatNode& other) {
-TORCH_CHECK(false, "NYI");
-}
-virtual SymFloatNode pow(const SymFloatNode& other) {
-TORCH_CHECK(false, "NYI");
-}
-virtual SymFloatNode eq(const SymFloatNode& other) {
-TORCH_CHECK(false, "NYI");
-};
-virtual SymFloatNode ne(const SymFloatNode& other) {
-TORCH_CHECK(false, "NYI");
-};
-virtual SymFloatNode gt(const SymFloatNode& other) {
-TORCH_CHECK(false, "NYI");
-};
-virtual SymFloatNode lt(const SymFloatNode& other) {
-TORCH_CHECK(false, "NYI");
-};
-virtual SymFloatNode le(const SymFloatNode& other) {
-TORCH_CHECK(false, "NYI");
-};
-virtual SymFloatNode ge(const SymFloatNode& other) {
-TORCH_CHECK(false, "NYI");
-};
-virtual SymIntNode ceil();
-virtual SymIntNode floor();
-virtual std::string str() {
-TORCH_CHECK(false, "NYI");
-};
-std::ostream& operator<<(std::ostream& os) {
-os << str();
-return os;
-};
-};
-
-} // namespace c10
@@ -1,47 +1,46 @@
 #include <c10/core/SymFloat.h>
 #include <c10/core/SymInt.h>
-#include <c10/core/SymIntNodeImpl.h>
+#include <c10/core/SymNodeImpl.h>
 #include <array>
 
 namespace c10 {
 
-static std::array<SymIntNode, 2> normalize_symints(SymInt a_, SymInt b_) {
-SymIntNode a, b;
+static std::array<SymNode, 2> normalize_symints(SymInt a_, SymInt b_) {
+SymNode a, b;
 if (a_.is_symbolic())
-a = a_.toSymIntNodeImpl();
+a = a_.toSymNodeImpl();
 if (b_.is_symbolic())
-b = b_.toSymIntNodeImpl();
+b = b_.toSymNodeImpl();
 
-SymIntNodeImpl* common = a ? a.get() : b.get();
+SymNodeImpl* common = a ? a.get() : b.get();
 // TODO: technically we need to check that the classes match
 if (!a) {
-a = common->wrap(a_.as_int_unchecked());
-a_.toSymInt(a); //
+a = common->wrap_int(a_.as_int_unchecked());
 }
 if (!b) {
-b = common->wrap(b_.as_int_unchecked());
-b_.toSymInt(b);
+b = common->wrap_int(b_.as_int_unchecked());
 }
 return {a, b};
 }
 
-SymIntNode SymInt::toSymIntNodeImpl() const {
+SymNode SymInt::toSymNodeImpl() const {
 TORCH_CHECK(is_symbolic());
-return SymIntNode::reclaim_copy(toSymIntNodeImplUnowned());
+return SymNode::reclaim_copy(toSymNodeImplUnowned());
 }
 
-c10::SymInt SymInt::toSymInt(SymIntNode sin_sp) {
+SymInt::SymInt(SymNode sin_sp) {
+TORCH_CHECK(sin_sp->is_int());
 auto ptr = static_cast<uint64_t>(
 reinterpret_cast<uintptr_t>(static_cast<void*>(sin_sp.release())));
 auto rep = (ptr & ~MASK) | IS_SYM;
-return c10::SymInt(UNCHECKED, static_cast<int64_t>(rep));
+data_ = static_cast<int64_t>(rep);
 }
 
 int64_t SymInt::guard_int(const char* file, int64_t line) const {
 if (!is_symbolic()) {
 return data_;
 }
-SymIntNode a = toSymIntNodeImpl();
+SymNode a = toSymNodeImpl();
 return a->guard_int(file, line);
 }
 
@@ -49,7 +48,7 @@ SymInt::operator SymFloat() const {
 if (!is_symbolic()) {
 return SymFloat(double(data_));
 }
-return SymFloat::toSymFloat(toSymIntNodeImpl()->sym_float());
+return SymFloat(toSymNodeImpl()->sym_float());
 }
 
 SymInt SymInt::operator+(SymInt sci) const {
@@ -57,7 +56,7 @@ SymInt SymInt::operator+(SymInt sci) const {
 return SymInt(data_ + sci.data_);
 }
 auto res = normalize_symints(*this, sci);
-return SymInt::toSymInt(res[0]->add(res[1]));
+return SymInt(res[0]->add(res[1]));
 }
 
 SymInt SymInt::operator-(SymInt sci) const {
@@ -65,7 +64,7 @@ SymInt SymInt::operator-(SymInt sci) const {
 return SymInt(data_ - sci.data_);
 }
 auto res = normalize_symints(*this, sci);
-return SymInt::toSymInt(res[0]->sub(res[1]));
+return SymInt(res[0]->sub(res[1]));
 }
 
 SymInt SymInt::operator*(SymInt sci) const {
@@ -73,7 +72,7 @@ SymInt SymInt::operator*(SymInt sci) const {
 return SymInt(data_ * sci.data_);
 }
 auto res = normalize_symints(*this, sci);
-return SymInt::toSymInt(res[0]->mul(res[1]));
+return SymInt(res[0]->mul(res[1]));
 }
 
 SymInt SymInt::operator/(SymInt sci) const {
@@ -81,7 +80,7 @@ SymInt SymInt::operator/(SymInt sci) const {
 return SymInt(data_ / sci.data_);
 }
 auto res = normalize_symints(*this, sci);
-return SymInt::toSymInt(res[0]->floordiv(res[1]));
+return SymInt(res[0]->floordiv(res[1]));
 }
 
 SymInt SymInt::operator%(SymInt sci) const {
@@ -89,7 +88,7 @@ SymInt SymInt::operator%(SymInt sci) const {
 return SymInt(data_ % sci.data_);
 }
 auto res = normalize_symints(*this, sci);
-return SymInt::toSymInt(res[0]->mod(res[1]));
+return SymInt(res[0]->mod(res[1]));
 }
 
 bool SymInt::operator==(SymInt sci) const {
@@ -141,14 +140,14 @@ SymInt SymInt::min(SymInt sci) const {
 return std::min(data_, sci.data_);
 }
 auto res = normalize_symints(*this, sci);
-return SymInt::toSymInt(res[0]->min(res[1]));
+return SymInt(res[0]->min(res[1]));
 }
 SymInt SymInt::max(SymInt sci) const {
 if (!is_symbolic() && !sci.is_symbolic()) {
 return std::max(data_, sci.data_);
 }
 auto res = normalize_symints(*this, sci);
-return SymInt::toSymInt(res[0]->max(res[1]));
+return SymInt(res[0]->max(res[1]));
 }
 
 void SymInt::operator*=(SymInt sci) {
@@ -193,7 +192,7 @@ SymInt SymInt::operator*(int64_t sci) const {
 
 std::ostream& operator<<(std::ostream& os, SymInt s) {
 if (s.is_symbolic()) {
-os << s.toSymIntNodeImpl()->str();
+os << s.toSymNodeImpl()->str();
 } else {
 os << s.as_int_unchecked();
 }
@@ -202,7 +201,7 @@ std::ostream& operator<<(std::ostream& os, SymInt s) {
 
 SymInt operator-(SymInt s) {
 if (s.is_symbolic()) {
-return SymInt::toSymInt(s.toSymIntNodeImpl()->neg());
+return SymInt(s.toSymNodeImpl()->neg());
 } else {
 return SymInt(-s.as_int_unchecked());
 }
@@ -1,6 +1,6 @@
 #pragma once
 
-#include <c10/core/SymIntNodeImpl.h>
+#include <c10/core/SymNodeImpl.h>
 #include <c10/macros/Macros.h>
 #include <c10/util/Exception.h>
 #include <c10/util/intrusive_ptr.h>
@@ -12,24 +12,19 @@ namespace c10 {
 
 class SymFloat;
 
-// `SymInt` is a C++ wrapper class around int64_t data_ which and is used to
-// represent concrete dimension values.
+// SymInt represents either a regular int64_t, or a symbolic integer
+// (represented in a type erased way as SymNode). The intention is for SymInt
+// to represent symbolic sizes that arise when doing shape computation in
+// operator kernels. This allows for tracing through programs without baking in
+// concrete sizes into kernel calls.
 //
-// `SymInt` is also a data type in Pytorch that can be used in function schemas
-// to enable tracing.
+// SymInt has an API equivalent to int64_t. In particular, it is a value type.
+// Internally, SymInt is represented in a clever packed way, so that it only
+// occupies one word of space; but morally, it is a union between an int64_t
+// and an intrusive pointer to SymNodeImpl.
 //
-// `SymInt` is introduced to enable tracing arithmetic
-// operations on symbolic integers (e.g. sizes). Tracing symbolic sizes will
-// allow LTC and AOTAutograd representing dynamic shapes in expression graphs
-// faithfully without baking in concrete dimension values.
-//
-// To trace the operations, SymInt will overload arithmetic operators (e.g. +,
-// -, *) and will provide overloads taking SymInt for commonly used math
-// functions.
-//
-// SymInt will be extenteded to represent a union structure Union[int64_t,
-// SymIntNodeImpl*] which will be implemented as a single packed int64_t field
-// named data_.
+// Invariant: the referenced SymNodeImpl is guaranteed to be a SymNode where
+// is_int() returns true
 
 class C10_API SymInt {
 public:
@@ -44,6 +39,7 @@ class C10_API SymInt {
 TORCH_CHECK(!is_symbolic());
 };
 SymInt() : data_(0) {}
+SymInt(SymNode n);
 
 // unchecked c-tor accepting raw `data_`
 // One appropriate use for this is when you are constructing a symint
@@ -55,7 +51,7 @@ class C10_API SymInt {
 // temporary and then use the move constructor/assignment
 SymInt(const SymInt& s) : data_(0) {
 if (s.is_symbolic()) {
-*this = SymInt::toSymInt(s.toSymIntNodeImpl());
+*this = SymInt(s.toSymNodeImpl());
 } else {
 data_ = s.data_;
 }
@@ -67,7 +63,7 @@ class C10_API SymInt {
 SymInt& operator=(const SymInt& s) {
 if (this != &s) {
 if (s.is_symbolic()) {
-*this = SymInt::toSymInt(s.toSymIntNodeImpl());
+*this = SymInt(s.toSymNodeImpl());
 } else {
 data_ = s.data_;
 }
@@ -76,7 +72,7 @@ class C10_API SymInt {
 }
 SymInt& operator=(SymInt&& s) {
 if (this != &s) {
-release_(); // release the current SymIntNode if any
+release_(); // release the current SymNode if any
 data_ = s.data_;
 if (s.is_symbolic())
 s.data_ = 0;
@@ -86,31 +82,31 @@ class C10_API SymInt {
 
 SymInt clone() const {
 if (is_symbolic()) {
-return toSymIntNodeImplUnowned()->clone()->toSymInt();
+return SymInt(toSymNodeImplUnowned()->clone());
 }
 return *this;
 }
 
-SymIntNodeImpl* toSymIntNodeImplUnowned() const {
+SymNodeImpl* toSymNodeImplUnowned() const {
 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(is_symbolic());
 uint64_t unextended_bits = static_cast<uint64_t>(data_) & ~MASK;
 uint64_t sign_bit_mask = 1ULL << (62 - 1);
 // https://stackoverflow.com/questions/42534749/signed-extension-from-24-bit-to-32-bit-in-c
 uint64_t extended_bits = (unextended_bits ^ sign_bit_mask) - sign_bit_mask;
-return static_cast<SymIntNodeImpl*>(
+return static_cast<SymNodeImpl*>(
 reinterpret_cast<void*>(static_cast<uintptr_t>(extended_bits)));
 }
 
 void release_() {
 if (is_symbolic()) {
-SymIntNode::reclaim(toSymIntNodeImplUnowned()); // steal
+SymNode::reclaim(toSymNodeImplUnowned()); // steal
 }
 }
 
-SymIntNodeImpl* release() && {
+SymNodeImpl* release() && {
 #ifndef C10_MOBILE
 TORCH_INTERNAL_ASSERT(is_symbolic());
-auto* r = toSymIntNodeImplUnowned();
+auto* r = toSymNodeImplUnowned();
 data_ = 0; // transfer ownership
 return r;
 #else
@@ -118,8 +114,7 @@ class C10_API SymInt {
 #endif
 }
 
-SymIntNode toSymIntNodeImpl() const;
-static c10::SymInt toSymInt(SymIntNode sin);
+SymNode toSymNodeImpl() const;
 
 ~SymInt() {
 release_();
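A side note on the packed representation referenced in the SymInt comment and toSymNodeImplUnowned() above: it is a conventional pointer-tagging scheme. The sketch below is standalone and uses illustrative tag constants; the real MASK/IS_SYM values are defined in SymInt.h and are not part of this diff's context lines.

// Standalone sketch of pointer tagging: one int64_t word holds either a
// plain integer or a tagged pointer. Constants here are illustrative, not
// the exact c10 definitions.
#include <cassert>
#include <cstdint>

constexpr uint64_t MASK = 3ULL << 62;    // assumed: top two bits are the tag
constexpr uint64_t IS_SYM = 1ULL << 63;  // assumed tag for "holds a pointer"

inline int64_t pack(void* p) {
  auto bits = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(p));
  return static_cast<int64_t>((bits & ~MASK) | IS_SYM);
}

inline bool is_symbolic(int64_t data) {
  return (static_cast<uint64_t>(data) & MASK) == IS_SYM;
}

inline void* unpack(int64_t data) {
  // Strip the tag, then sign-extend from bit 61 -- the same
  // (x ^ sign_bit) - sign_bit trick used by toSymNodeImplUnowned() above --
  // so that pointers with high bits set still round-trip.
  uint64_t unextended_bits = static_cast<uint64_t>(data) & ~MASK;
  uint64_t sign_bit_mask = 1ULL << 61;
  uint64_t extended_bits = (unextended_bits ^ sign_bit_mask) - sign_bit_mask;
  return reinterpret_cast<void*>(static_cast<uintptr_t>(extended_bits));
}

int main() {
  int n = 42;
  int64_t packed = pack(&n);
  assert(is_symbolic(packed));
  assert(unpack(packed) == static_cast<void*>(&n));
}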
@@ -1,11 +0,0 @@
-#include <c10/core/SymInt.h>
-#include <c10/core/SymIntNodeImpl.h>
-
-namespace c10 {
-
-c10::SymInt SymIntNodeImpl::toSymInt() {
-auto sit_sp = SymIntNode::reclaim_copy(this);
-return SymInt::toSymInt(sit_sp);
-}
-
-} // namespace c10
c10/core/SymNodeImpl.cpp (new file)
@@ -0,0 +1,3 @@
+#include <c10/core/SymNodeImpl.h>
+
+namespace c10 {} // namespace c10
@@ -1,6 +1,5 @@
 #pragma once
 
-#include <c10/core/SymFloatNodeImpl.h>
 #include <c10/macros/Macros.h>
 #include <c10/util/Exception.h>
 #include <c10/util/intrusive_ptr.h>
@@ -10,13 +9,12 @@
 
 namespace c10 {
 
-class SymInt;
-class SymIntNodeImpl;
+class SymNodeImpl;
+using SymNode = c10::intrusive_ptr<SymNodeImpl>;
 
-class C10_API SymIntNodeImpl : public c10::intrusive_ptr_target {
+class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
 public:
-c10::SymInt toSymInt();
-virtual ~SymIntNodeImpl(){};
+virtual ~SymNodeImpl(){};
 
 template <typename T>
 c10::intrusive_ptr<T> dyn_cast() const {
@@ -24,66 +22,87 @@ class C10_API SymIntNodeImpl : public c10::intrusive_ptr_target {
 }
 
 // these could be pure virtual when we implement LTC versions
-virtual SymIntNode add(const SymIntNode& other) {
+virtual bool is_int() {
 TORCH_CHECK(false, "NYI");
 };
-virtual SymIntNode sub(const SymIntNode& other) {
+virtual bool is_float() {
 TORCH_CHECK(false, "NYI");
 };
-virtual SymIntNode mul(const SymIntNode& other) {
+virtual SymNode add(const SymNode& other) {
 TORCH_CHECK(false, "NYI");
 };
-virtual SymFloatNode truediv(const SymIntNode& other) {
+virtual SymNode sub(const SymNode& other) {
 TORCH_CHECK(false, "NYI");
 };
-virtual SymIntNode floordiv(const SymIntNode& other) {
+virtual SymNode mul(const SymNode& other) {
 TORCH_CHECK(false, "NYI");
 };
-virtual SymIntNode mod(const SymIntNode& other) {
+virtual SymNode truediv(const SymNode& other) {
 TORCH_CHECK(false, "NYI");
 };
-virtual SymIntNode eq(const SymIntNode& other) {
+virtual SymNode pow(const SymNode& other) {
 TORCH_CHECK(false, "NYI");
 };
-virtual SymIntNode ne(const SymIntNode& other) {
+virtual SymNode floordiv(const SymNode& other) {
 TORCH_CHECK(false, "NYI");
 };
-virtual SymIntNode gt(const SymIntNode& other) {
+virtual SymNode mod(const SymNode& other) {
 TORCH_CHECK(false, "NYI");
 };
-virtual SymIntNode lt(const SymIntNode& other) {
+virtual SymNode eq(const SymNode& other) {
 TORCH_CHECK(false, "NYI");
 };
-virtual SymIntNode le(const SymIntNode& other) {
+virtual SymNode ne(const SymNode& other) {
 TORCH_CHECK(false, "NYI");
 };
-virtual SymIntNode ge(const SymIntNode& other) {
+virtual SymNode gt(const SymNode& other) {
 TORCH_CHECK(false, "NYI");
 };
-virtual SymIntNode ceil() {
+virtual SymNode lt(const SymNode& other) {
 TORCH_CHECK(false, "NYI");
 };
-virtual SymIntNode neg() {
+virtual SymNode le(const SymNode& other) {
 TORCH_CHECK(false, "NYI");
 };
-virtual SymIntNode min(const SymIntNode& other) {
+virtual SymNode ge(const SymNode& other) {
 TORCH_CHECK(false, "NYI");
 };
-virtual SymIntNode max(const SymIntNode& other) {
+virtual SymNode ceil() {
 TORCH_CHECK(false, "NYI");
 };
-virtual SymIntNode clone() {
+virtual SymNode floor() {
 TORCH_CHECK(false, "NYI");
 };
-virtual SymFloatNode sym_float() {
+virtual SymNode neg() {
+TORCH_CHECK(false, "NYI");
+};
+virtual SymNode min(const SymNode& other) {
+TORCH_CHECK(false, "NYI");
+};
+virtual SymNode max(const SymNode& other) {
+TORCH_CHECK(false, "NYI");
+};
+virtual SymNode clone() {
+TORCH_CHECK(false, "NYI");
+};
+virtual SymNode sym_int() {
 TORCH_CHECK(false, "NYI");
 }
-virtual SymIntNode wrap(int64_t num) {
+virtual SymNode sym_float() {
+TORCH_CHECK(false, "NYI");
+}
+virtual SymNode wrap_int(int64_t num) {
+TORCH_CHECK(false, "NYI");
+};
+virtual SymNode wrap_float(double num) {
 TORCH_CHECK(false, "NYI");
 };
 virtual int64_t guard_int(const char* file, int64_t line) {
 TORCH_CHECK(false, "NYI");
 };
+virtual double guard_float(const char* file, int64_t line) {
+TORCH_CHECK(false, "NYI");
+};
 virtual int64_t int_() {
 TORCH_CHECK(false, "NYI");
 };
@@ -1,7 +1,7 @@
 #include <gtest/gtest.h>
 
 #include <c10/core/SymInt.h>
-#include <c10/core/SymIntNodeImpl.h>
+#include <c10/core/SymNodeImpl.h>
 
 using namespace c10;
 #ifndef C10_MOBILE
@@ -20,12 +20,6 @@ TEST(SymIntTest, ConcreteInts) {
 check(-4611686018427387904LL);
 }
 
-TEST(SymIntTest, AddNode) {
-auto n = c10::make_intrusive<SymIntNodeImpl>();
-auto i = n->toSymInt();
-EXPECT_TRUE(i.is_symbolic());
-}
-
 TEST(SymIntTest, CheckRange) {
 EXPECT_FALSE(SymInt::check_range(INT64_MIN));
 }
@@ -335,8 +335,8 @@ coverage_ignore_classes = [
 "Quantize",
 # torch.utils.backcompat
 "Warning",
-"SymIntNode",
-"SymFloatNode",
+"SymInt",
+"SymFloat",
 ]
 
 # The suffix(es) of source filenames.
@@ -605,7 +605,7 @@ class PytreeThunk:
 return x
 return pytree.tree_unflatten(x, self.spec)
 
-KNOWN_TYPES = [torch.Tensor, int, str, float, bool, torch.SymIntNode, torch.SymFloatNode]
+KNOWN_TYPES = [torch.Tensor, int, str, float, bool, torch.SymInt, torch.SymFloat]
 
 
 def aot_function(
@@ -209,7 +209,7 @@ def _tensor_nbytes(numel, dtype):
 
 def _size_of(node: fx.Node) -> int:
 def to_size_hint(s):
-if isinstance(s, torch.SymIntNode):
+if isinstance(s, torch.SymInt):
 py_s = s.get_pyobj()
 return py_s.shape_env.size_hint(py_s.expr)
 assert isinstance(s, int)
@@ -18,6 +18,8 @@ cond = PyOperator('cond')
 
 def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
 def _unwrap_proxy(e):
+if not isinstance(e, (torch.Tensor, torch.SymInt, torch.SymFloat)):
+return e
 return get_proxy_slot(e, proxy_mode.tracer, e, lambda e: e.proxy)
 
 assert isinstance(operands, list), "Cond operands must be a list of tensors"
@@ -1447,35 +1447,29 @@ TEST(TestSymInt, AddSymbolicInt) {
 }
 
 #ifndef C10_MOBILE
-TEST(TestSymInt, TestIntrusive) {
-auto a = c10::make_intrusive<c10::SymIntNodeImpl>();
-auto b = c10::make_intrusive<c10::SymIntNodeImpl>();
-ASSERT_EQ(a.use_count(), 1);
-ASSERT_EQ(b.use_count(), 1);
-auto as = a->toSymInt();
-auto bs = b->toSymInt();
-ASSERT_EQ(a.use_count(), 2);
-ASSERT_EQ(b.use_count(), 2);
-as = bs;
-ASSERT_EQ(a.use_count(), 1);
-ASSERT_EQ(b.use_count(), 3);
-}
-
-class TestSymIntNodeImpl : public c10::SymIntNodeImpl {
+class TestSymNodeImpl : public c10::SymNodeImpl {
 public:
-TestSymIntNodeImpl(int64_t i) : i_(i) {}
+explicit TestSymNodeImpl(int64_t i) : i_(i) {}
 
+bool is_int() override {
+return true;
+};
+
+bool is_float() override {
+return false;
+};
+
 bool bool_() override {
 return static_cast<bool>(i_);
 };
 
 #define OPDEF3(NAME, OP, RET) \
-RET NAME(const c10::SymIntNode& other) override { \
-return make_intrusive<TestSymIntNodeImpl>( \
-this->i_ OP dynamic_cast<TestSymIntNodeImpl*>(other.get())->i_); \
+RET NAME(const c10::SymNode& other) override { \
+return make_intrusive<TestSymNodeImpl>( \
+this->i_ OP dynamic_cast<TestSymNodeImpl*>(other.get())->i_); \
 }
 
-#define OPDEF2(NAME, OP) OPDEF3(NAME, OP, c10::SymIntNode)
+#define OPDEF2(NAME, OP) OPDEF3(NAME, OP, c10::SymNode)
 OPDEF2(add, +)
 OPDEF2(sub, -)
 OPDEF2(mul, *)
@@ -1494,17 +1488,19 @@ class TestSymIntNodeImpl : public c10::SymIntNodeImpl {
 int64_t i_;
 };
 
-TEST(TestSymInt, TestSymIntToSymIntNodeDispatch) {
+TEST(TestSymInt, TestSymIntToSymNodeDispatch) {
 auto get = [](c10::SymInt si) {
-auto node = si.toSymIntNodeImpl();
-return dynamic_cast<TestSymIntNodeImpl*>(node.get())->i_;
+auto node = si.toSymNodeImpl();
+return dynamic_cast<TestSymNodeImpl*>(node.get())->i_;
 };
 
 std::vector<int64_t> inputs{0, 1, -1, 4, -4, 777, -777};
 for (auto i : inputs) {
 for (auto j : inputs) {
-auto a = c10::make_intrusive<TestSymIntNodeImpl>(i)->toSymInt();
-auto b = c10::make_intrusive<TestSymIntNodeImpl>(j)->toSymInt();
+auto a = c10::SymInt(
+static_cast<SymNode>(c10::make_intrusive<TestSymNodeImpl>(i)));
+auto b = c10::SymInt(
+static_cast<SymNode>(c10::make_intrusive<TestSymNodeImpl>(j)));
 ASSERT_EQ(get(a + b), i + j);
 ASSERT_EQ(get(a - b), i - j);
 ASSERT_EQ(get(a * b), i * j);
@@ -12,8 +12,9 @@ import itertools
 import io
 from torch.utils._pytree import tree_map
 from torch.fx.experimental.proxy_tensor import make_fx
-from torch.fx.experimental.symbolic_shapes import ShapeEnv, PySymInt, sym_float
+from torch.fx.experimental.symbolic_shapes import ShapeEnv, sym_float, guard_int, SymNode
 from torch.utils._python_dispatch import TorchDispatchMode
+from torch import SymInt
 
 aten = torch.ops.aten
 
@@ -116,9 +117,6 @@ def create_symbolic_tensor(name, arg, shape_env, storage_offset=0):
 sym_shapes, sym_strides = shape_env.create_symbolic_sizes_strides(arg)
 return FakeSymbolicTensor(sym_shapes, sym_strides, arg.dtype, arg.layout, arg.requires_grad, arg.device, storage_offset)
 
-
-CPP_SYMINT_CLASS = type(torch.SymIntNode.new_symint(1))
-
 def create_symint(shape_env, i):
 return shape_env.create_symintnode(shape_env.create_symbol(i))
 
@@ -156,8 +154,8 @@ class TestPySymInt(TestCase):
 shape_env = ShapeEnv()
 x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
 
-self.assertTrue(not isinstance(x.shape[0], PySymInt))
-self.assertTrue(isinstance(x.shape[0], CPP_SYMINT_CLASS))
+self.assertTrue(not isinstance(x.shape[0], SymNode))
+self.assertTrue(isinstance(x.shape[0], SymInt))
 
 self.assertTrue(x.shape[0] == 5)
 self.assertTrue(x.shape[1] == 4)
@@ -165,17 +163,17 @@ class TestPySymInt(TestCase):
 
 self.assertTrue(x.size()[0], 5)
 self.assertTrue(x.size()[1], 4)
-self.assertTrue(isinstance(x.size()[1], CPP_SYMINT_CLASS))
+self.assertTrue(isinstance(x.size()[1], SymInt))
 self.assertTrue(x.size()[2] == 3)
 
 self.assertTrue(x.size(0) == 5)
 self.assertTrue(x.size(1) == 4)
 self.assertTrue(x.size(2) == 3)
-self.assertTrue(isinstance(x.size(2), CPP_SYMINT_CLASS))
+self.assertTrue(isinstance(x.size(2), SymInt))
 
 offset = create_symint(shape_env, 2)
 y = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env, offset)
-self.assertTrue(isinstance(y.storage_offset(), CPP_SYMINT_CLASS))
+self.assertTrue(isinstance(y.storage_offset(), SymInt))
 self.assertTrue(y.storage_offset() == 2)
 
 offset = 2
@@ -267,7 +265,7 @@ class TestPySymInt(TestCase):
 def test_stride(self):
 shape_env = ShapeEnv()
 x = create_symbolic_tensor("x", torch.randn(5, 5), shape_env)
-self.assertIsInstance(x.stride()[0], CPP_SYMINT_CLASS)
+self.assertIsInstance(x.stride()[0], SymInt)
 
 @skipIfNoSympy
 def test_size_expressions(self):
@@ -290,7 +288,7 @@ class TestPySymInt(TestCase):
 shape_env = ShapeEnv()
 x = create_symbolic_tensor("x", torch.randn(5), shape_env)
 r = sym_float(x.shape[0])
-self.assertTrue(isinstance(r, torch.SymFloatNode))
+self.assertIsInstance(r, torch.SymFloat, msg=type(r))
 
 @skipIfNoSympy
 def test_aten_ops(self):
@@ -320,13 +318,13 @@ class TestPySymInt(TestCase):
 shape_env = ShapeEnv()
 a0 = create_symint(shape_env, 2)
 r = torch.empty(a0, device='meta')
-self.assertIsInstance(r.shape[0], CPP_SYMINT_CLASS)
+self.assertIsInstance(r.shape[0], SymInt)
 
 @skipIfNoSympy
 def test_guard_int(self):
 shape_env = ShapeEnv()
 a0 = create_symint(shape_env, 2)
-self.assertEqual(a0.guard_int(), 2)
+self.assertEqual(guard_int(a0), 2)
 self.assertEqual(str(shape_env.guards[0][0]), "Eq(s0, 2)")
 
 @skipIfNoSympy
@@ -347,7 +345,9 @@ class TestPySymInt(TestCase):
 assert func == torch.ops.aten.add.Tensor
 
 nonlocal sym_int_encountered
-sym_int_encountered = kwargs["alpha"] is a0
+# WARNING: do not do identity tests on the outer
+# SymInt/SymFloat, they are NOT STABLE
+sym_int_encountered = kwargs["alpha"].node is a0.node
 kwargs["alpha"] = 0
 return func(*args)
 
@@ -1,391 +0,0 @@
-# -*- coding: utf-8 -*-
-# Owner(s): ["oncall: jit"]
-
-from torch._C import _disabled_torch_function_impl
-import torch.fx
-import torch.nn.functional as F
-from torch.testing._internal.common_utils import run_tests, TestCase, skipIfTorchDynamo
-import unittest
-import torch
-import operator
-import itertools
-import io
-from torch.utils._pytree import tree_map
-from torch.fx.experimental.proxy_tensor import make_fx
-from torch.fx.experimental.symbolic_shapes import ShapeEnv, PySymInt, sym_float
-from torch.utils._python_dispatch import TorchDispatchMode
-
-aten = torch.ops.aten
-
-try:
-import sympy
-HAS_SYMPY = True
-except ImportError:
-HAS_SYMPY = False
-skipIfNoSympy = unittest.skipIf(not HAS_SYMPY, "no sympy")
-
-
-meta_funcs = {}
-
-
-def register_meta(op):
-def decorator(f):
-def add_func(op):
-meta_funcs[op] = f
-tree_map(add_func, op)
-return f
-return decorator
-
-
-@register_meta([aten.add.Tensor, aten.sub.Tensor])
-def binary_meta(a, b):
-return a.new_empty(a.shape)
-
-
-@register_meta(aten.cat.default)
-def cat_meta(tensors, dim=0):
-concat_length = 0
-shape = tensors[0].shape
-for tensor in tensors:
-for idx, (common_length, length) in enumerate(zip(shape, tensor.shape)):
-if idx == dim:
-concat_length = concat_length + length
-else:
-assert length == common_length
-new_shape = list(shape)
-new_shape[dim] = concat_length
-return tensors[0].new_empty(new_shape)
-
-
-@register_meta([aten.narrow_copy.default])
-def narrow_copy_symint_meta(a, dim, start, length, **kwargs):
-shape = []
-for i, x in enumerate(a.shape):
-if i == dim:
-shape.append(length)
-else:
-shape.append(x)
-return a.new_empty(tuple(shape))
-
-
-@register_meta([aten.expand.default])
-def expand_symint_meta(a, size, implicit=False):
-return a.new_empty(size)
-
-
-def create_contiguous(shape):
-strides = [1]
-for dim in reversed(shape[:-1]):
-strides.append(dim * strides[-1])
-return list(reversed(strides))
-
-
-class FakeSymbolicTensor(torch.Tensor):
-@staticmethod
-def __new__(cls, sym_shape, sym_strides, dtype, layout, requires_grad, device, storage_offset=0):
-# TODO: this is wrong in general
-sym_stride = create_contiguous(sym_shape)
-r = torch.Tensor._make_wrapper_subclass(
-cls, sym_shape,
-sym_stride, storage_offset,
-dtype=dtype, layout=layout, requires_grad=requires_grad,
-device=device,
-)
-return r
-
-__torch_function__ = _disabled_torch_function_impl
-
-def new_empty(self, shape):
-return FakeSymbolicTensor(shape, None, self.dtype, self.layout, self.requires_grad, self.device)
-
-@classmethod
-def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
-if func_overload in meta_funcs:
-return meta_funcs[func_overload](*args, **kwargs)
-
-if func_overload == torch.ops.aten.new_empty.default:
-self = args[0]
-shape = args[1]
-return FakeSymbolicTensor(shape, self.stride(), self.dtype, self.layout, self.requires_grad, self.device)
-
-raise RuntimeError(f"operator {func_overload} not supported")
-
-
-def create_symbolic_tensor(name, arg, shape_env, storage_offset=0):
-sym_shapes, sym_strides = shape_env.create_symbolic_sizes_strides(arg)
-return FakeSymbolicTensor(sym_shapes, sym_strides, arg.dtype, arg.layout, arg.requires_grad, arg.device, storage_offset)
-
-
-CPP_SYMINT_CLASS = type(torch.SymIntNode.new_symint(1))
-
-def create_symint(shape_env, i):
-return shape_env.create_symintnode(shape_env.create_symbol(i))
-
-@skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)")
-class TestPySymInt(TestCase):
-
-@skipIfNoSympy
-def test_arith_ops(self):
-shape_env = ShapeEnv()
-symints = []
-for i in range(2, 5):
-symints.append((i, create_symint(shape_env, i)))
-
-ops = [operator.add, operator.sub, operator.floordiv, operator.mul, operator.mod]
-
-for op in ops:
-for args in itertools.permutations(symints, 2):
-if not isinstance(args[0][1], int) and ((op != operator.mod or op != operator.floordiv) and args[1][0] != 0):
-self.assertTrue(op(args[0][1], args[1][1]) == op(args[0][0], args[1][0]))
-
-
-@skipIfNoSympy
-def test_reverse_arith_ops(self):
-shape_env = ShapeEnv()
-
-a = create_symint(shape_env, 2)
-self.assertTrue(5 // a == 5 // 2)
-
-a = create_symint(shape_env, 2)
-self.assertTrue(5 * a == 5 * 2)
-
-
-@skipIfNoSympy
-def test_roundtrip(self):
-shape_env = ShapeEnv()
-x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
-
-self.assertTrue(not isinstance(x.shape[0], PySymInt))
-self.assertTrue(isinstance(x.shape[0], CPP_SYMINT_CLASS))
|
|
||||||
|
|
||||||
self.assertTrue(x.shape[0] == 5)
|
|
||||||
self.assertTrue(x.shape[1] == 4)
|
|
||||||
self.assertTrue(x.shape[2], 3)
|
|
||||||
|
|
||||||
self.assertTrue(x.size()[0], 5)
|
|
||||||
self.assertTrue(x.size()[1], 4)
|
|
||||||
self.assertTrue(isinstance(x.size()[1], CPP_SYMINT_CLASS))
|
|
||||||
self.assertTrue(x.size()[2] == 3)
|
|
||||||
|
|
||||||
self.assertTrue(x.size(0) == 5)
|
|
||||||
self.assertTrue(x.size(1) == 4)
|
|
||||||
self.assertTrue(x.size(2) == 3)
|
|
||||||
self.assertTrue(isinstance(x.size(2), CPP_SYMINT_CLASS))
|
|
||||||
|
|
||||||
offset = create_symint(shape_env, 2)
|
|
||||||
y = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env, offset)
|
|
||||||
self.assertTrue(isinstance(y.storage_offset(), CPP_SYMINT_CLASS))
|
|
||||||
self.assertTrue(y.storage_offset() == 2)
|
|
||||||
|
|
||||||
offset = 2
|
|
||||||
z = create_symbolic_tensor("z", torch.randn(5, 4, 3), shape_env, offset)
|
|
||||||
self.assertTrue(isinstance(z.storage_offset(), int))
|
|
||||||
self.assertTrue(z.storage_offset() == 2)
|
|
||||||
|
|
||||||
@skipIfNoSympy
|
|
||||||
def test_binary(self):
|
|
||||||
shape_env = ShapeEnv()
|
|
||||||
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
|
|
||||||
y = create_symbolic_tensor("y", torch.randn(5, 4, 3), shape_env)
|
|
||||||
|
|
||||||
z = x + y
|
|
||||||
self.assertTrue(z.shape[0] == 5)
|
|
||||||
self.assertTrue(z.shape[1] == 4)
|
|
||||||
self.assertTrue(z.shape[2] == 3)
|
|
||||||
|
|
||||||
# broadcasting
|
|
||||||
y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env)
|
|
||||||
z = x + y
|
|
||||||
self.assertTrue(z.shape[0] == 5)
|
|
||||||
self.assertTrue(z.shape[1] == 4)
|
|
||||||
self.assertTrue(z.shape[2] == 3)
|
|
||||||
|
|
||||||
@skipIfNoSympy
|
|
||||||
def test_symint_args(self):
|
|
||||||
shape_env = ShapeEnv()
|
|
||||||
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
|
|
||||||
y = create_symbolic_tensor("y", torch.randn(5, 4, 1), shape_env)
|
|
||||||
LAST_DIM = 2
|
|
||||||
z = x.narrow_copy(LAST_DIM, 0, y.shape[LAST_DIM])
|
|
||||||
self.assertTrue(z.shape[2] == y.shape[2])
|
|
||||||
|
|
||||||
# arithmetic expr with two symints
|
|
||||||
z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - y.shape[LAST_DIM])
|
|
||||||
self.assertTrue(z.shape[2] == 2)
|
|
||||||
|
|
||||||
# arithmetic expr with a symint and python int
|
|
||||||
z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - 1)
|
|
||||||
self.assertTrue(z.shape[2] == 2)
|
|
||||||
|
|
||||||
@skipIfNoSympy
|
|
||||||
def test_symint_vargs(self):
|
|
||||||
shape_env = ShapeEnv()
|
|
||||||
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
|
|
||||||
y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env)
|
|
||||||
|
|
||||||
# varargs
|
|
||||||
z = y.expand(x.shape[0], y.shape[1], x.shape[2])
|
|
||||||
self.assertTrue(z.shape[0] == 5)
|
|
||||||
self.assertTrue(z.shape[1] == 4)
|
|
||||||
self.assertTrue(z.shape[2] == 3)
|
|
||||||
|
|
||||||
# shape list
|
|
||||||
z = y.expand((x.shape[0], y.shape[1], x.shape[2]))
|
|
||||||
self.assertTrue(z.shape[0] == 5)
|
|
||||||
self.assertTrue(z.shape[1] == 4)
|
|
||||||
self.assertTrue(z.shape[2] == 3)
|
|
||||||
|
|
||||||
# mixed python symints and ints
|
|
||||||
z = y.expand(x.shape[0], y.shape[1], 3)
|
|
||||||
self.assertTrue(z.shape[0] == 5)
|
|
||||||
self.assertTrue(z.shape[1] == 4)
|
|
||||||
self.assertTrue(z.shape[2] == 3)
|
|
||||||
|
|
||||||
# mixed python symints and ints in a list
|
|
||||||
z = y.expand((x.shape[0], y.shape[1], 3))
|
|
||||||
self.assertTrue(z.shape[0] == 5)
|
|
||||||
self.assertTrue(z.shape[1] == 4)
|
|
||||||
self.assertTrue(z.shape[2] == 3)
|
|
||||||
|
|
||||||
# mixed python symints and ints
|
|
||||||
z = y.expand(5, y.shape[1], x.shape[2])
|
|
||||||
self.assertTrue(z.shape[0] == 5)
|
|
||||||
self.assertTrue(z.shape[1] == 4)
|
|
||||||
self.assertTrue(z.shape[2] == 3)
|
|
||||||
|
|
||||||
# mixed python ints and symints in a list
|
|
||||||
z = y.expand((5, y.shape[1], x.shape[2]))
|
|
||||||
self.assertTrue(z.shape[0] == 5)
|
|
||||||
self.assertTrue(z.shape[1] == 4)
|
|
||||||
self.assertTrue(z.shape[2] == 3)
|
|
||||||
|
|
||||||
z = y.expand((y.shape[1],))
|
|
||||||
z = y.expand(y.shape[1])
|
|
||||||
|
|
||||||
@skipIfNoSympy
|
|
||||||
def test_stride(self):
|
|
||||||
shape_env = ShapeEnv()
|
|
||||||
x = create_symbolic_tensor("x", torch.randn(5, 5), shape_env)
|
|
||||||
self.assertIsInstance(x.stride()[0], CPP_SYMINT_CLASS)
|
|
||||||
|
|
||||||
@skipIfNoSympy
|
|
||||||
def test_size_expressions(self):
|
|
||||||
shape_env = ShapeEnv()
|
|
||||||
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
|
|
||||||
expand_x = x.expand(x.shape[0], x.shape[0])
|
|
||||||
if expand_x.shape[0] > 3:
|
|
||||||
result = expand_x + expand_x
|
|
||||||
else:
|
|
||||||
result = expand_x + expand_x
|
|
||||||
|
|
||||||
gt_op = shape_env.guards[0][0]
|
|
||||||
self.assertTrue(isinstance(gt_op, sympy.core.relational.StrictGreaterThan))
|
|
||||||
self.assertTrue(str(x.shape[0]), str(gt_op.args[0]))
|
|
||||||
self.assertTrue(str(expand_x.shape[1]), str(x.shape[0]))
|
|
||||||
self.assertTrue(str(expand_x.shape[1]), str(result.shape[0]))
|
|
||||||
|
|
||||||
@skipIfNoSympy
|
|
||||||
def test_int_to_float(self):
|
|
||||||
shape_env = ShapeEnv()
|
|
||||||
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
|
|
||||||
r = sym_float(x.shape[0])
|
|
||||||
self.assertTrue(isinstance(r, torch.SymFloatNode))
|
|
||||||
|
|
||||||
@skipIfNoSympy
|
|
||||||
def test_aten_ops(self):
|
|
||||||
|
|
||||||
shape_env = ShapeEnv()
|
|
||||||
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
|
|
||||||
torch.ops.aten.narrow_copy.default(x, 0, 0, x.shape[0])
|
|
||||||
|
|
||||||
shape_env = ShapeEnv()
|
|
||||||
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
|
|
||||||
torch.ops.aten.expand.default(x, [x.shape[0], x.shape[1], x.shape[2]])
|
|
||||||
|
|
||||||
def test_fx_trace_intlist(self):
|
|
||||||
class CustomModule(torch.nn.Module):
|
|
||||||
def forward(self, x):
|
|
||||||
bs, c, h, w = x.shape
|
|
||||||
return F.pad(x, (0, w % 2, 0, h % 2, 0, 0))
|
|
||||||
|
|
||||||
m = CustomModule()
|
|
||||||
x = torch.rand(1, 3, 4, 4)
|
|
||||||
# should not TypeError: pad(): argument 'pad' (position 2) must be
|
|
||||||
# tuple of ints, not tuple
|
|
||||||
torch.fx.symbolic_trace(m)
|
|
||||||
|
|
||||||
@skipIfNoSympy
|
|
||||||
def test_meta_symint(self):
|
|
||||||
shape_env = ShapeEnv()
|
|
||||||
a0 = create_symint(shape_env, 2)
|
|
||||||
r = torch.empty(a0, device='meta')
|
|
||||||
self.assertIsInstance(r.shape[0], CPP_SYMINT_CLASS)
|
|
||||||
|
|
||||||
@skipIfNoSympy
|
|
||||||
def test_guard_int(self):
|
|
||||||
shape_env = ShapeEnv()
|
|
||||||
a0 = create_symint(shape_env, 2)
|
|
||||||
self.assertEqual(a0.guard_int(), 2)
|
|
||||||
self.assertEqual(str(shape_env.guards[0][0]), "s0")
|
|
||||||
self.assertEqual(shape_env.guards[0][1], 2)
|
|
||||||
|
|
||||||
@skipIfNoSympy
|
|
||||||
def test_int_conversion(self):
|
|
||||||
shape_env = ShapeEnv()
|
|
||||||
a0 = create_symint(shape_env, 2)
|
|
||||||
self.assertRaisesRegex(RuntimeError, "Trying to extract", lambda: int(a0))
|
|
||||||
|
|
||||||
@skipIfNoSympy
|
|
||||||
def test_symint_as_scalar(self):
|
|
||||||
shape_env = ShapeEnv()
|
|
||||||
a0 = create_symint(shape_env, 2)
|
|
||||||
|
|
||||||
sym_int_encountered = False
|
|
||||||
|
|
||||||
class TestSymInt(TorchDispatchMode):
|
|
||||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
|
||||||
assert func == torch.ops.aten.add.Tensor
|
|
||||||
|
|
||||||
nonlocal sym_int_encountered
|
|
||||||
sym_int_encountered = kwargs["alpha"] is a0
|
|
||||||
kwargs["alpha"] = 0
|
|
||||||
return func(*args)
|
|
||||||
|
|
||||||
x = torch.rand([4, 4])
|
|
||||||
with TestSymInt():
|
|
||||||
y = torch.add(x, x, alpha=a0)
|
|
||||||
|
|
||||||
self.assertTrue(sym_int_encountered)
|
|
||||||
|
|
||||||
@skipIfNoSympy
|
|
||||||
@unittest.mock.patch('sys.stdout', new_callable=io.StringIO)
|
|
||||||
def test_print_readable_with_symints(self, mock_stdout):
|
|
||||||
def f(a, b):
|
|
||||||
dim0 = a.shape[0] + b.shape[0]
|
|
||||||
dim1 = a.shape[1] + b.shape[1]
|
|
||||||
d = a.new_empty(dim0, dim1)
|
|
||||||
d = torch.ops.aten.native_dropout(d, 0.5, train=True)
|
|
||||||
return d
|
|
||||||
|
|
||||||
fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5, 3), torch.randn(4, 3))
|
|
||||||
fx_g.print_readable()
|
|
||||||
|
|
||||||
self.assertExpectedInline(mock_stdout.getvalue().strip(), """\
|
|
||||||
class f(torch.nn.Module):
|
|
||||||
def forward(self, a_1: f32[t0.size(0),t0.size(1)], b_1: f32[t1.size(0),t0.size(1)]):
|
|
||||||
# No stacktrace found for following nodes
|
|
||||||
sym_size: Sym(t0.size(0)) = torch.ops.aten.sym_size(a_1, 0)
|
|
||||||
sym_size_1: Sym(t1.size(0)) = torch.ops.aten.sym_size(b_1, 0)
|
|
||||||
add: Sym(t0.size(0) + t1.size(0)) = sym_size + sym_size_1; sym_size = sym_size_1 = None
|
|
||||||
sym_size_2: Sym(t0.size(1)) = torch.ops.aten.sym_size(a_1, 1)
|
|
||||||
sym_size_3: Sym(t0.size(1)) = torch.ops.aten.sym_size(b_1, 1); b_1 = None
|
|
||||||
add_1: Sym(2*t0.size(1)) = sym_size_2 + sym_size_3; sym_size_2 = sym_size_3 = None
|
|
||||||
new_empty: f32[t0.size(0) + t1.size(0),2*t0.size(1)] = torch.ops.aten.new_empty.default(a_1, [add, add_1], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False); a_1 = add = add_1 = None
|
|
||||||
native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True); new_empty = None
|
|
||||||
getitem: f32[t0.size(0) + t1.size(0),2*t0.size(1)] = native_dropout[0]
|
|
||||||
getitem_1: b8[t0.size(0) + t1.size(0),2*t0.size(1)] = native_dropout[1]; native_dropout = None
|
|
||||||
return (getitem, getitem_1)""") # noqa: B950
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
run_tests()
|
|
@@ -875,8 +875,7 @@ def forward(self, a_1):
         self.assertExpectedInline(r, """\
 def forward(self, a_1):
     sym_size = torch.ops.aten.sym_size(a_1, 0)
-    sym_float = torch.fx.experimental.symbolic_shapes.sym_float(sym_size); sym_size = None
-    pow_1 = sym_float ** 0.5; sym_float = None
+    pow_1 = sym_size ** 0.5; sym_size = None
     div = torch.ops.aten.div.Tensor(a_1, pow_1); a_1 = pow_1 = None
     return div""")

@@ -949,7 +948,7 @@ def forward(self, a_1):
         fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5), torch.randn(4))
         meta_c = _get_node(fx_g, lambda x: x.target == aten.new_empty.default)
         meta_d = _get_node(fx_g, lambda x: x.target == operator.add)
-        self.assertTrue(meta_c.meta['val'].shape[0].get_pyobj().expr == meta_d.meta['val'].expr)
+        self.assertTrue(meta_c.meta['val'].shape[0].get_pyobj().expr == meta_d.meta['val'].node.expr)

     def test_metadata_fresh(self):
         def f(x):
@@ -207,8 +207,8 @@ class TestPublicBindings(TestCase):
             "StreamObjType",
             "StringType",
             "SUM",
-            "SymFloatNode",
-            "SymIntNode",
+            "SymFloat",
+            "SymInt",
             "TensorType",
             "ThroughputBenchmark",
             "TracingState",
@@ -291,7 +291,7 @@ PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
 for (auto i : c10::irange(prop.size())) {
   auto si = prop[i];
   if (si.is_symbolic()) {
-    auto py_symint = py::cast(si.toSymIntNodeImpl()).release().ptr();
+    auto py_symint = py::cast(si).release().ptr();
     PyTuple_SetItem(tup, (Py_ssize_t) i, py_symint);
   } else {
     PyTuple_SetItem(tup, (Py_ssize_t) i, PyLong_FromUnsignedLong(si.as_int_unchecked()));
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
GETTER_BODY_SYMINT = """\
|
GETTER_BODY_SYMINT = """\
|
||||||
return prop.is_symbolic() ? py::cast(prop.toSymIntNodeImpl()).release().ptr() : PyLong_FromUnsignedLong(prop.as_int_unchecked());
|
return prop.is_symbolic() ? py::cast(prop).release().ptr() : PyLong_FromUnsignedLong(prop.as_int_unchecked());
|
||||||
"""
|
"""
|
||||||
|
|
||||||
GETTER_BODY_DOUBLE = """\
|
GETTER_BODY_DOUBLE = """\
|
||||||
|
@@ -5,7 +5,7 @@
 #include <Python.h>
 #include <ATen/ATen.h>

-#include <c10/core/SymIntNodeImpl.h>
+#include <c10/core/SymNodeImpl.h>
 #include "torch/csrc/autograd/generated/Functions.h"
 #include "torch/csrc/autograd/python_cpp_function.h"
 #include <torch/csrc/autograd/python_variable.h>
@@ -240,12 +240,7 @@ static PyObject * THPVariable_numel(PyObject* self, PyObject* args)
   if (jit::tracer::isTracing()) {
     return wrap(jit::tracer::getNumelOf(self_));
   } else {
-    auto si = self_.sym_numel();
-    if (si.is_symbolic()) {
-      return py::cast(si.toSymIntNodeImpl()).release().ptr();
-    } else {
-      return THPUtils_packInt64(si.as_int_unchecked());
-    }
+    return py::cast(self_.sym_numel()).release().ptr();
   }
   END_HANDLE_TH_ERRORS
 }
@@ -722,7 +722,7 @@ def gen_pyi(
             binop += "_"
             out_suffix = ""
         unsorted_tensor_method_hints[binop].append(
-            "def {}(self, other: Union[Tensor, Number, torch.SymIntNode, torch.SymFloatNode]{})"
+            "def {}(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat]{})"
             " -> Tensor: ...".format(binop, out_suffix)
         )
     for binop in ["add", "sub"]:
@@ -732,7 +732,7 @@ def gen_pyi(
             binop += "_"
             out_suffix = ""
         unsorted_tensor_method_hints[binop].append(
-            "def {}(self, other: Union[Tensor, Number, torch.SymIntNode, torch.SymFloatNode], "
+            "def {}(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat], "
             "*, alpha: Optional[Number]=1{})"
             " -> Tensor: ...".format(binop, out_suffix)
         )
@@ -169,20 +169,6 @@ class Future(object):

 def _jit_set_num_profiled_runs(num: _size) -> _size: ...

-class SymIntNode(object):
-    def get_pyobj(self) -> Any: ...
-
-    @staticmethod
-    def new_symint(obj) -> SymIntNode: ...
-
-class SymFloatNode(object):
-    def get_pyobj(self) -> Any: ...
-
-    @staticmethod
-    def new_symfloat(obj) -> SymFloatNode: ...
-
-    def __ceil__(self) -> SymIntNode: ...
-
 # Defined in torch/csrc/jit/passes/xnnpack_rewrite.h
 class MobileOptimizerType:
     ...
@@ -47,7 +47,7 @@ __all__ = [
     'is_deterministic_algorithms_warn_only_enabled',
     'set_deterministic_debug_mode', 'get_deterministic_debug_mode',
     'set_float32_matmul_precision', 'get_float32_matmul_precision',
-    'set_warn_always', 'is_warn_always_enabled',
+    'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat',
 ]

 ################################################################################
@ -196,6 +196,67 @@ else:
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import torch._C as _C
|
import torch._C as _C
|
||||||
|
|
||||||
|
class SymInt:
|
||||||
|
"""
|
||||||
|
Like an int (including magic methods), but redirects all operations on the
|
||||||
|
wrapped node. This is used in particular to symbolically record operations
|
||||||
|
in the symbolic shape workflow.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, node):
|
||||||
|
from torch.fx.experimental.symbolic_shapes import SymNode
|
||||||
|
assert isinstance(node, SymNode)
|
||||||
|
# This field MUST be named node; C++ binding code assumes that this
|
||||||
|
# class has a field named node that stores SymNode
|
||||||
|
self.node = node
|
||||||
|
|
||||||
|
# Magic methods installed later
|
||||||
|
|
||||||
|
def __bool__(self):
|
||||||
|
return self.node.bool_()
|
||||||
|
|
||||||
|
def __int__(self):
|
||||||
|
return self.node.int_()
|
||||||
|
|
||||||
|
def __sym_float__(self):
|
||||||
|
return SymFloat(self.node.sym_float())
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return self.node.str()
|
||||||
|
|
||||||
|
# For BC; direct access of node is OK too
|
||||||
|
def get_pyobj(self):
|
||||||
|
return self.node
|
||||||
|
|
||||||
|
class SymFloat:
|
||||||
|
"""
|
||||||
|
Like an float (including magic methods), but redirects all operations on the
|
||||||
|
wrapped node. This is used in particular to symbolically record operations
|
||||||
|
in the symbolic shape workflow.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, node):
|
||||||
|
from torch.fx.experimental.symbolic_shapes import SymNode
|
||||||
|
assert isinstance(node, SymNode)
|
||||||
|
# This field MUST be named node; C++ binding code assumes that this
|
||||||
|
# class has a field named node that stores SymNode
|
||||||
|
self.node = node
|
||||||
|
|
||||||
|
# Magic methods installed later
|
||||||
|
|
||||||
|
def __bool__(self):
|
||||||
|
return self.node.bool_()
|
||||||
|
|
||||||
|
def __sym_int__(self):
|
||||||
|
return SymInt(self.node.sym_int())
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return self.node.str()
|
||||||
|
|
||||||
|
# For BC; direct access of node is OK too
|
||||||
|
def get_pyobj(self):
|
||||||
|
return self.node
|
||||||
|
|
||||||
# Check to see if we can load C extensions, and if not provide some guidance
|
# Check to see if we can load C extensions, and if not provide some guidance
|
||||||
# on what the problem might be.
|
# on what the problem might be.
|
||||||
try:
|
try:
|
||||||
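A short usage sketch of the wrapper added above (assuming ShapeEnv.create_symintnode now returns the new torch.SymInt; the helper names follow the tests earlier in this patch):

import torch
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()
s = shape_env.create_symintnode(shape_env.create_symbol(5))
assert isinstance(s, torch.SymInt)
# Arithmetic goes through the magic methods installed later, which
# delegate to s.node; get_pyobj() is kept only for backward compatibility.
t = s + 1
assert s.get_pyobj() is s.node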
@ -941,7 +1002,6 @@ from ._linalg_utils import ( # type: ignore[misc]
|
|||||||
lstsq,
|
lstsq,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _register_device_module(device_type, module):
|
def _register_device_module(device_type, module):
|
||||||
r"""Register an external runtime module of the specific :attr:`device_type`
|
r"""Register an external runtime module of the specific :attr:`device_type`
|
||||||
supported by torch.
|
supported by torch.
|
||||||
@ -971,3 +1031,6 @@ if 'TORCH_CUDA_SANITIZER' in os.environ:
|
|||||||
import torch.cuda._sanitizer as csan
|
import torch.cuda._sanitizer as csan
|
||||||
|
|
||||||
csan.enable_cuda_sanitizer()
|
csan.enable_cuda_sanitizer()
|
||||||
|
|
||||||
|
# Populate magic methods on SymInt and SymFloat
|
||||||
|
import torch.fx.experimental.symbolic_shapes
|
||||||
|
@@ -337,7 +337,7 @@ class TensorVariable(VariableTracker):
                 from . import UserDefinedObjectVariable

                 return UserDefinedObjectVariable(example_value)
-        elif isinstance(example_value, torch.SymIntNode):
+        elif isinstance(example_value, torch.SymInt):
             proxy.node.meta["example_value"] = example_value
             return cls(proxy, **options)
         else:
@@ -40,11 +40,9 @@ class GraphLowering(torch.fx.Interpreter):
         else:
             size, stride = self._shape_env.create_symbolic_sizes_strides(ex)

-            size = [
-                i.get_pyobj().expr if isinstance(i, torch.SymIntNode) else i for i in size
-            ]
+            size = [i.get_pyobj().expr if isinstance(i, torch.SymInt) else i for i in size]
             stride = [
-                i.get_pyobj().expr if isinstance(i, torch.SymIntNode) else i for i in stride
+                i.get_pyobj().expr if isinstance(i, torch.SymInt) else i for i in stride
             ]
             return size, stride

@@ -392,8 +392,8 @@ def _elementwise_meta(
     # Number case
     # NOTE: this case is not currently exercised
     # TODO: fix number type promotion (bool, complex->float)
-    assert not isinstance(number, torch.SymIntNode), "NYI"
-    assert not isinstance(number, torch.SymFloatNode), "NYI"
+    assert not isinstance(number, torch.SymInt), "NYI"
+    assert not isinstance(number, torch.SymFloat), "NYI"
     return TensorMeta(number)


@ -932,7 +932,7 @@ bitwise_xor = _make_elementwise_binary_prim(
|
|||||||
# div prim performs truncation division on integer inputs
|
# div prim performs truncation division on integer inputs
|
||||||
# and true division for floating and complex inputs
|
# and true division for floating and complex inputs
|
||||||
def _div_aten(a, b):
|
def _div_aten(a, b):
|
||||||
is_integral = isinstance(a, (bool, int, torch.SymIntNode)) or (
|
is_integral = isinstance(a, (bool, int, torch.SymInt)) or (
|
||||||
isinstance(a, torch.Tensor) and utils.is_integer_dtype(a.dtype)
|
isinstance(a, torch.Tensor) and utils.is_integer_dtype(a.dtype)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -42,18 +42,18 @@ ShapeType = Union[torch.Size, List[int], Tuple[int, ...]]
|
|||||||
StrideType = Union[List[int], Tuple[int, ...]]
|
StrideType = Union[List[int], Tuple[int, ...]]
|
||||||
DimsType = Union[int, List[int], Tuple[int, ...]]
|
DimsType = Union[int, List[int], Tuple[int, ...]]
|
||||||
DimsSequenceType = Union[List[int], Tuple[int, ...]]
|
DimsSequenceType = Union[List[int], Tuple[int, ...]]
|
||||||
# TODO: Type[torch.SymIntNode], Type[torch.SymFloatNode]
|
# TODO: Type[torch.SymInt], Type[torch.SymFloat]
|
||||||
NumberTypeType = Union[Type[bool], Type[int], Type[float], Type[complex]]
|
NumberTypeType = Union[Type[bool], Type[int], Type[float], Type[complex]]
|
||||||
# TODO: This needs a lot more type annotations
|
# TODO: This needs a lot more type annotations
|
||||||
# NumberType = Union[bool, int, float, complex, torch.SymIntNode, torch.SymFloatNode]
|
# NumberType = Union[bool, int, float, complex, torch.SymInt, torch.SymFloat]
|
||||||
NumberType = Union[bool, int, float, complex]
|
NumberType = Union[bool, int, float, complex]
|
||||||
|
|
||||||
Number = (bool, int, float, complex, torch.SymIntNode, torch.SymFloatNode)
|
Number = (bool, int, float, complex, torch.SymInt, torch.SymFloat)
|
||||||
# I don't call it Integral because numbers.Integral includes bool, but IntLike
|
# I don't call it Integral because numbers.Integral includes bool, but IntLike
|
||||||
# does not
|
# does not
|
||||||
Dim = int
|
Dim = int
|
||||||
IntLike = (int, torch.SymIntNode)
|
IntLike = (int, torch.SymInt)
|
||||||
FloatLike = (float, torch.SymFloatNode)
|
FloatLike = (float, torch.SymFloat)
|
||||||
IntWithoutSymInt = int
|
IntWithoutSymInt = int
|
||||||
FloatWithoutSymFloat = float
|
FloatWithoutSymFloat = float
|
||||||
DeviceLikeType = Union[str, torch.device]
|
DeviceLikeType = Union[str, torch.device]
|
||||||
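These tuples exist so that a single isinstance check covers both a plain number and its symbolic counterpart. A small sketch, assuming the aliases are importable from torch._prims_common:

import torch
from torch._prims_common import IntLike, FloatLike

def classify(x):
    # isinstance with a tuple accepts either the builtin type or its
    # symbolic counterpart (torch.SymInt / torch.SymFloat).
    if isinstance(x, IntLike):
        return "int-like"
    if isinstance(x, FloatLike):
        return "float-like"
    return "other"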
@ -1113,10 +1113,10 @@ class RETURN_TYPE(Enum):
|
|||||||
|
|
||||||
|
|
||||||
# TODO: when NumberType contains the sym types, can simplify this
|
# TODO: when NumberType contains the sym types, can simplify this
|
||||||
def number_type(x: Union[NumberType, torch.SymIntNode, torch.SymFloatNode]) -> Type:
|
def number_type(x: Union[NumberType, torch.SymInt, torch.SymFloat]) -> Type:
|
||||||
if isinstance(x, torch.SymIntNode):
|
if isinstance(x, torch.SymInt):
|
||||||
return int
|
return int
|
||||||
elif isinstance(x, torch.SymFloatNode):
|
elif isinstance(x, torch.SymFloat):
|
||||||
return float
|
return float
|
||||||
else:
|
else:
|
||||||
return type(x)
|
return type(x)
|
||||||
|
@@ -656,7 +656,7 @@ class FakeTensorMode(TorchDispatchMode):
                 return args[0].fake_device

         flat_arg_fake_tensors = tree_flatten_only(FakeTensor, (args, kwargs))
-        flat_symints = tree_flatten_only(torch.SymIntNode, (args, kwargs))
+        flat_symints = tree_flatten_only(torch.SymInt, (args, kwargs))
         has_symbolic_sizes = (
             any([i._has_symbolic_sizes_strides for i in flat_arg_fake_tensors])
             or len(flat_symints) > 0
@@ -59,7 +59,7 @@ PyObject* THPSize_NewFromSymSizes(const at::Tensor& self_) {
       TORCH_CHECK(
           !torch::jit::tracer::isTracing(),
           "JIT Tracing of SymInts isn't supported");
-      auto py_symint = py::cast(si.toSymIntNodeImpl()).release().ptr();
+      auto py_symint = py::cast(si).release().ptr();
       if (!py_symint)
         throw python_error();
       PyTuple_SET_ITEM(ret.get(), i, py_symint);
|
|||||||
if (THPUtils_checkLong(item)) {
|
if (THPUtils_checkLong(item)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (torch::is_symint_node(item)) {
|
if (torch::is_symint(item)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (torch::jit::tracer::isTracing() && isTracedZeroDimVar(item)) {
|
if (torch::jit::tracer::isTracing() && isTracedZeroDimVar(item)) {
|
||||||
@ -135,7 +135,7 @@ static PyObject* THPSize_repr(THPSize* self) {
|
|||||||
auto item = PyTuple_GET_ITEM(self, i);
|
auto item = PyTuple_GET_ITEM(self, i);
|
||||||
auto ih = py::handle(item);
|
auto ih = py::handle(item);
|
||||||
|
|
||||||
repr += torch::is_symint_node(ih)
|
repr += torch::is_symint(ih)
|
||||||
? std::string(py::str(ih))
|
? std::string(py::str(ih))
|
||||||
: std::to_string(THPUtils_unpackLong(PyTuple_GET_ITEM(self, i)));
|
: std::to_string(THPUtils_unpackLong(PyTuple_GET_ITEM(self, i)));
|
||||||
}
|
}
|
||||||
|
@@ -2646,9 +2646,8 @@ c10::SymInt ConcretePyInterpreterVTable::sym_numel(
         "Cannot call numel on a tensor with symbolic shapes/strides");
     return self->sym_numel_default();
   }
-  return torch::is_symint_node(out)
-      ? out.cast<c10::SymIntNodeImpl*>()->toSymInt()
-      : c10::SymInt{py::cast<int64_t>(out)};
+  return torch::is_symint(out) ? out.cast<c10::SymInt>()
+                               : c10::SymInt{py::cast<int64_t>(out)};
 }

 c10::SymInt ConcretePyInterpreterVTable::sym_storage_offset(
|||||||
if (out.is(py::none())) {
|
if (out.is(py::none())) {
|
||||||
return self->sym_storage_offset_default();
|
return self->sym_storage_offset_default();
|
||||||
}
|
}
|
||||||
return torch::is_symint_node(out)
|
return torch::is_symint(out) ? out.cast<c10::SymInt>()
|
||||||
? out.cast<c10::SymIntNodeImpl*>()->toSymInt()
|
: c10::SymInt{py::cast<int64_t>(out)};
|
||||||
: c10::SymInt{py::cast<int64_t>(out)};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_strides(
|
c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_strides(
|
||||||
@ -2701,9 +2699,8 @@ c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_strides(
|
|||||||
py::list symints;
|
py::list symints;
|
||||||
for (auto it = out.begin(); it != out.end(); it++) {
|
for (auto it = out.begin(); it != out.end(); it++) {
|
||||||
auto elm = *it;
|
auto elm = *it;
|
||||||
auto si = torch::is_symint_node(elm)
|
auto si = torch::is_symint(elm) ? elm.cast<c10::SymInt>()
|
||||||
? elm.cast<c10::SymIntNodeImpl*>()->toSymInt()
|
: c10::SymInt{py::cast<int64_t>(elm)};
|
||||||
: c10::SymInt{py::cast<int64_t>(elm)};
|
|
||||||
symints.append(si.as_int_unchecked());
|
symints.append(si.as_int_unchecked());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -13,7 +13,7 @@
 #if (!defined(FBCODE_CAFFE2) && defined(BUILD_ONEDNN_GRAPH))
 #include <torch/csrc/jit/codegen/onednn/interface.h>
 #endif
-#include <c10/core/SymIntNodeImpl.h>
+#include <c10/core/SymNodeImpl.h>
 #include <torch/csrc/jit/frontend/ir_emitter.h>
 #include <torch/csrc/jit/frontend/tracer.h>
 #include <torch/csrc/jit/ir/irparser.h>
|
|||||||
#include <torch/csrc/jit/tensorexpr/tensorexpr_init.h>
|
#include <torch/csrc/jit/tensorexpr/tensorexpr_init.h>
|
||||||
#include <torch/csrc/utils/cpp_stacktraces.h>
|
#include <torch/csrc/utils/cpp_stacktraces.h>
|
||||||
|
|
||||||
#include <c10/core/SymFloat.h>
|
|
||||||
#include <c10/macros/Export.h>
|
#include <c10/macros/Export.h>
|
||||||
#include <c10/util/irange.h>
|
#include <c10/util/irange.h>
|
||||||
#include <c10/util/signal_handler.h>
|
#include <c10/util/signal_handler.h>
|
||||||
@ -126,249 +125,11 @@ using c10::Argument;
|
|||||||
using c10::FunctionSchema;
|
using c10::FunctionSchema;
|
||||||
using c10::SchemaArgType;
|
using c10::SchemaArgType;
|
||||||
using c10::SchemaArgument;
|
using c10::SchemaArgument;
|
||||||
using c10::SymFloat;
|
using c10::SymNode;
|
||||||
using c10::SymFloatNode;
|
|
||||||
using c10::SymIntNode;
|
|
||||||
using caffe2::serialize::PyTorchStreamReader;
|
using caffe2::serialize::PyTorchStreamReader;
|
||||||
using caffe2::serialize::PyTorchStreamWriter;
|
using caffe2::serialize::PyTorchStreamWriter;
|
||||||
using torch::utils::SchemaInfo;
|
using torch::utils::SchemaInfo;
|
||||||
|
|
||||||
static c10::SymIntNode toSymIntNode(c10::SymIntNode a, py::object b) {
|
|
||||||
return torch::is_symint_node(b) ? b.cast<c10::SymIntNode>()
|
|
||||||
: a->wrap(b.cast<int64_t>());
|
|
||||||
}
|
|
||||||
|
|
||||||
static c10::SymFloatNode toSymFloatNode(c10::SymFloatNode a, py::object b) {
|
|
||||||
if (torch::is_symfloat_node(b)) {
|
|
||||||
return b.cast<c10::SymFloatNode>();
|
|
||||||
} else if (torch::is_symint_node(b)) {
|
|
||||||
return b.cast<c10::SymIntNode>()->sym_float();
|
|
||||||
} else {
|
|
||||||
return a->wrap(b.cast<double>());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
class PythonSymIntNodeImpl : public c10::SymIntNodeImpl {
|
|
||||||
public:
|
|
||||||
PythonSymIntNodeImpl(py::object pyobj) : c10::SymIntNodeImpl() {
|
|
||||||
pyobj_ = std::make_shared<c10::SafePyObject>(
|
|
||||||
pyobj.release().ptr(), getPyInterpreter());
|
|
||||||
};
|
|
||||||
|
|
||||||
virtual SymIntNode clone() override {
|
|
||||||
py::gil_scoped_acquire acquire;
|
|
||||||
auto r = getPyObj().attr("clone")();
|
|
||||||
return c10::make_intrusive<PythonSymIntNodeImpl>(r);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual SymIntNode wrap(int64_t num) override {
|
|
||||||
py::gil_scoped_acquire acquire;
|
|
||||||
auto r = getPyObj().attr("wrap")(num);
|
|
||||||
return c10::make_intrusive<PythonSymIntNodeImpl>(r);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual bool bool_() override {
|
|
||||||
py::gil_scoped_acquire acquire;
|
|
||||||
return getPyObj().attr("__bool__")().is(py::handle(Py_True));
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual int64_t guard_int(const char* file, int64_t line) override {
|
|
||||||
py::gil_scoped_acquire acquire;
|
|
||||||
return getPyObj().attr("guard_int")(file, line).cast<int64_t>();
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual int64_t int_() override {
|
|
||||||
py::gil_scoped_acquire acquire;
|
|
||||||
return getPyObj().attr("__int__")().cast<int64_t>();
|
|
||||||
}
|
|
||||||
|
|
||||||
SymFloatNode sym_float() override;
|
|
||||||
|
|
||||||
virtual std::string str() override {
|
|
||||||
py::gil_scoped_acquire acquire;
|
|
||||||
return getPyObj().attr("__str__")().cast<std::string>();
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual SymIntNode dispatch_common_(
|
|
||||||
const char* fname,
|
|
||||||
const SymIntNode& other) {
|
|
||||||
auto pother = dynamic_cast<PythonSymIntNodeImpl*>(other.get());
|
|
||||||
TORCH_CHECK(pother);
|
|
||||||
py::gil_scoped_acquire acquire;
|
|
||||||
auto r = getPyObj().attr(fname)(pother->getPyObj());
|
|
||||||
return c10::make_intrusive<PythonSymIntNodeImpl>(r);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual SymIntNode dispatch_common_(const char* fname) {
|
|
||||||
py::gil_scoped_acquire acquire;
|
|
||||||
auto r = getPyObj().attr(fname)();
|
|
||||||
return c10::make_intrusive<PythonSymIntNodeImpl>(r);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual SymIntNode add(const SymIntNode& other) override {
|
|
||||||
return dispatch_common_(__FUNCTION__, other);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual SymIntNode sub(const SymIntNode& other) override {
|
|
||||||
return dispatch_common_(__FUNCTION__, other);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual SymIntNode mul(const SymIntNode& other) override {
|
|
||||||
return dispatch_common_(__FUNCTION__, other);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual SymFloatNode truediv(const SymIntNode& other) override;
|
|
||||||
|
|
||||||
virtual SymIntNode floordiv(const SymIntNode& other) override {
|
|
||||||
return dispatch_common_(__FUNCTION__, other);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual SymIntNode mod(const SymIntNode& other) override {
|
|
||||||
return dispatch_common_(__FUNCTION__, other);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual SymIntNode eq(const SymIntNode& other) override {
|
|
||||||
return dispatch_common_(__FUNCTION__, other);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual SymIntNode gt(const SymIntNode& other) override {
|
|
||||||
return dispatch_common_(__FUNCTION__, other);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual SymIntNode lt(const SymIntNode& other) override {
|
|
||||||
return dispatch_common_(__FUNCTION__, other);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual SymIntNode le(const SymIntNode& other) override {
|
|
||||||
return dispatch_common_(__FUNCTION__, other);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual SymIntNode ge(const SymIntNode& other) override {
|
|
||||||
return dispatch_common_(__FUNCTION__, other);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual SymIntNode min(const SymIntNode& other) override {
|
|
||||||
return dispatch_common_(__FUNCTION__, other);
|
|
||||||
}
|
|
||||||
virtual SymIntNode max(const SymIntNode& other) override {
|
|
||||||
return dispatch_common_(__FUNCTION__, other);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual SymIntNode ceil() override {
|
|
||||||
return dispatch_common_(__FUNCTION__);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual SymIntNode neg() override {
|
|
||||||
return dispatch_common_(__FUNCTION__);
|
|
||||||
}
|
|
||||||
|
|
||||||
py::handle getPyObj() {
|
|
||||||
return py::handle(pyobj_.get()->ptr(getPyInterpreter()));
|
|
||||||
}
|
|
||||||
std::shared_ptr<c10::SafePyObject> pyobj_ = nullptr;
|
|
||||||
};
|
|
||||||
|
|
||||||
class PythonSymFloatNodeImpl : public c10::SymFloatNodeImpl {
|
|
||||||
public:
|
|
||||||
PythonSymFloatNodeImpl(py::object pyobj) : c10::SymFloatNodeImpl() {
|
|
||||||
pyobj_ = std::make_shared<c10::SafePyObject>(
|
|
||||||
pyobj.release().ptr(), getPyInterpreter());
|
|
||||||
};
|
|
||||||
|
|
||||||
virtual SymFloatNode wrap(double num) override {
|
|
||||||
py::gil_scoped_acquire acquire;
|
|
||||||
auto r = getPyObj().attr("wrap")(num);
|
|
||||||
return c10::make_intrusive<PythonSymFloatNodeImpl>(r);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual std::string str() override {
|
|
||||||
py::gil_scoped_acquire acquire;
|
|
||||||
return getPyObj().attr("__str__")().cast<std::string>();
|
|
||||||
}
|
|
||||||
|
|
||||||
SymFloatNode dispatch_common_(const char* fname, const SymFloatNode& other) {
|
|
||||||
auto pother = dynamic_cast<PythonSymFloatNodeImpl*>(other.get());
|
|
||||||
TORCH_CHECK(pother);
|
|
||||||
py::gil_scoped_acquire acquire;
|
|
||||||
auto r = getPyObj().attr(fname)(pother->getPyObj());
|
|
||||||
return c10::make_intrusive<PythonSymFloatNodeImpl>(r);
|
|
||||||
}
|
|
||||||
|
|
||||||
SymFloatNode add(const SymFloatNode& other) override {
|
|
||||||
return dispatch_common_(__FUNCTION__, other);
|
|
||||||
}
|
|
||||||
|
|
||||||
SymFloatNode sub(const SymFloatNode& other) override {
|
|
||||||
return dispatch_common_(__FUNCTION__, other);
|
|
||||||
}
|
|
||||||
|
|
||||||
SymFloatNode mul(const SymFloatNode& other) override {
|
|
||||||
return dispatch_common_(__FUNCTION__, other);
|
|
||||||
}
|
|
||||||
|
|
||||||
SymFloatNode truediv(const SymFloatNode& other) override {
|
|
||||||
return dispatch_common_(__FUNCTION__, other);
|
|
||||||
}
|
|
||||||
|
|
||||||
SymFloatNode pow(const SymFloatNode& other) override {
|
|
||||||
return dispatch_common_(__FUNCTION__, other);
|
|
||||||
}
|
|
||||||
|
|
||||||
SymFloatNode eq(const SymFloatNode& other) override {
|
|
||||||
return dispatch_common_(__FUNCTION__, other);
|
|
||||||
}
|
|
||||||
|
|
||||||
SymFloatNode gt(const SymFloatNode& other) override {
|
|
||||||
return dispatch_common_(__FUNCTION__, other);
|
|
||||||
}
|
|
||||||
|
|
||||||
SymFloatNode lt(const SymFloatNode& other) override {
|
|
||||||
return dispatch_common_(__FUNCTION__, other);
|
|
||||||
}
|
|
||||||
|
|
||||||
SymFloatNode le(const SymFloatNode& other) override {
|
|
||||||
return dispatch_common_(__FUNCTION__, other);
|
|
||||||
}
|
|
||||||
|
|
||||||
SymFloatNode ge(const SymFloatNode& other) override {
|
|
||||||
return dispatch_common_(__FUNCTION__, other);
|
|
||||||
}
|
|
||||||
|
|
||||||
SymIntNode ceil() override;
|
|
||||||
SymIntNode floor() override;
|
|
||||||
|
|
||||||
py::handle getPyObj() {
|
|
||||||
return py::handle(pyobj_.get()->ptr(getPyInterpreter()));
|
|
||||||
}
|
|
||||||
std::shared_ptr<c10::SafePyObject> pyobj_ = nullptr;
|
|
||||||
};
|
|
||||||
|
|
||||||
SymFloatNode PythonSymIntNodeImpl::truediv(const SymIntNode& other) {
|
|
||||||
auto pother = dynamic_cast<PythonSymIntNodeImpl*>(other.get());
|
|
||||||
TORCH_CHECK(pother);
|
|
||||||
py::gil_scoped_acquire acquire;
|
|
||||||
auto r = getPyObj().attr("truediv")(pother->getPyObj());
|
|
||||||
return c10::make_intrusive<PythonSymFloatNodeImpl>(r);
|
|
||||||
}
|
|
||||||
|
|
||||||
SymFloatNode PythonSymIntNodeImpl::sym_float() {
|
|
||||||
py::gil_scoped_acquire acquire;
|
|
||||||
return c10::make_intrusive<PythonSymFloatNodeImpl>(
|
|
||||||
getPyObj().attr("__sym_float__")());
|
|
||||||
}
|
|
||||||
|
|
||||||
SymIntNode PythonSymFloatNodeImpl::ceil() {
|
|
||||||
py::gil_scoped_acquire acquire;
|
|
||||||
auto r = getPyObj().attr("ceil")();
|
|
||||||
return c10::make_intrusive<PythonSymIntNodeImpl>(r);
|
|
||||||
}
|
|
||||||
|
|
||||||
SymIntNode PythonSymFloatNodeImpl::floor() {
|
|
||||||
py::gil_scoped_acquire acquire;
|
|
||||||
auto r = getPyObj().attr("floor")();
|
|
||||||
return c10::make_intrusive<PythonSymIntNodeImpl>(r);
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
using autograd::variable_list;
|
using autograd::variable_list;
|
||||||
@@ -1381,276 +1142,41 @@ void initJITBindings(PyObject* module) {
         }
       });

-  auto symint_class =
-      py::class_<c10::SymIntNodeImpl, c10::SymIntNode>(m, "SymIntNode")
-          .def_static(
-              "new_symint",
-              [](py::object obj) -> c10::SymIntNode {
-                return c10::make_intrusive<PythonSymIntNodeImpl>(obj);
-              })
-          .def(
-              "get_pyobj",
-              [](c10::SymIntNode a) -> py::object {
-                if (auto* psn = dynamic_cast<PythonSymIntNodeImpl*>(a.get())) {
-                  return py::reinterpret_borrow<py::object>(psn->getPyObj());
-                }
-                return py::none();
-              })
-          .def(
-              "__add__",
-              [](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
-                auto snb = toSymIntNode(a, b);
-                return a->add(snb);
-              })
-          .def(
-              "__radd__",
-              [](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
-                auto snb = toSymIntNode(a, b);
-                return snb->add(a);
-              })
-          .def(
-              "__sub__",
-              [](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
-                auto snb = toSymIntNode(a, b);
-                return a->sub(snb);
-              })
-          .def(
-              "__rsub__",
-              [](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
-                auto snb = toSymIntNode(a, b);
-                return snb->sub(a);
-              })
-          .def(
-              "__mul__",
-              [](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
-                auto snb = toSymIntNode(a, b);
-                return a->mul(snb);
-              })
-          .def(
-              "__rmul__",
-              [](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
-                auto snb = toSymIntNode(a, b);
-                return snb->mul(a);
-              })
-          .def(
-              "__truediv__",
-              [](c10::SymIntNode a, py::object b) -> c10::SymFloatNode {
-                auto snb = toSymIntNode(a, b);
-                return a->truediv(snb);
-              })
-          .def(
-              "__rtruediv__",
-              [](c10::SymIntNode a, py::object b) -> c10::SymFloatNode {
-                auto snb = toSymIntNode(a, b);
-                return snb->truediv(a);
-              })
-          .def(
-              "__floordiv__",
-              [](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
-                auto snb = toSymIntNode(a, b);
-                return a->floordiv(snb);
-              })
-          .def(
-              "__rfloordiv__",
-              [](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
-                auto snb = toSymIntNode(a, b);
-                return snb->floordiv(a);
-              })
-          .def(
-              "__mod__",
-              [](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
-                auto snb = toSymIntNode(a, b);
-                return a->mod(snb);
-              })
-          .def(
-              "__rmod__",
-              [](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
-                auto snb = toSymIntNode(a, b);
-                return snb->mod(a);
-              })
-          .def(
-              "__pow__",
-              [](c10::SymIntNode a, py::object b) -> py::object {
-                if (PyFloat_Check(b.ptr())) {
-                  auto float_a = a->sym_float();
-                  return py::cast(
-                      float_a->pow(float_a->wrap(py::cast<double>(b))));
-                }
-                // TODO: integer pow
-                return py::reinterpret_borrow<py::object>(Py_NotImplemented);
-              })
-          // TODO: rpow
-          .def(
-              "__eq__",
-              [](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
-                auto snb = toSymIntNode(a, b);
-                return a->eq(snb);
-              })
-          .def(
-              "__gt__",
-              [](c10::SymIntNode a, py::object b) {
-                auto snb = toSymIntNode(a, b);
-                return a->gt(snb);
-              })
-          .def(
-              "__lt__",
-              [](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
-                auto snb = toSymIntNode(a, b);
-                return a->lt(snb);
-              })
-          .def(
-              "__le__",
-              [](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
-                auto snb = toSymIntNode(a, b);
-                return a->le(snb);
-              })
-          .def(
-              "__ge__",
-              [](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
-                auto snb = toSymIntNode(a, b);
-                return a->ge(snb);
-              })
-          .def(
-              "__ceil__",
-              [](c10::SymIntNode a) -> c10::SymIntNode { return a->ceil(); })
-          .def(
-              "__neg__",
-              [](c10::SymIntNode a) -> c10::SymIntNode { return a->neg(); })
-          .def(
-              "__min__",
-              [](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
-                auto snb = toSymIntNode(a, b);
-                return a->min(snb);
-              })
-          .def(
-              "__max__",
-              [](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
-                auto snb = toSymIntNode(a, b);
-                return a->max(snb);
-              })
-          .def("__bool__", [](c10::SymIntNode a) { return a->bool_(); })
-          .def("__int__", [](c10::SymIntNode a) { return a->int_(); })
-          // Intentionally don't set file line, as the Python backtrace matters
-          // more here
-          .def(
-              "guard_int",
-              [](c10::SymIntNode a) { return a->guard_int(nullptr, 0); })
-          .def(
-              "__sym_float__",
-              [](c10::SymIntNode a) {
-                // TODO: remove dynamic cast when sym_float is in base class
-                auto* psn = dynamic_cast<PythonSymIntNodeImpl*>(a.get());
-                TORCH_INTERNAL_ASSERT(psn);
-                return psn->sym_float();
-              })
-          .def("__str__", [](c10::SymIntNode a) { return a->str(); })
-          .def("__repr__", [](c10::SymIntNode a) { return a->str(); });
-
-  py::class_<c10::SymFloatNodeImpl, c10::SymFloatNode>(m, "SymFloatNode")
-      .def_static(
-          "new_symfloat",
-          [](py::object obj) -> c10::SymFloatNode {
-            return c10::make_intrusive<PythonSymFloatNodeImpl>(obj);
-          })
-      .def(
-          "__add__",
-          [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
-            auto snb = toSymFloatNode(a, b);
-            return a->add(snb);
-          })
-      .def(
-          "__radd__",
-          [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
-            auto snb = toSymFloatNode(a, b);
-            return snb->add(a);
-          })
-      .def(
-          "__sub__",
-          [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
-            auto snb = toSymFloatNode(a, b);
-            return a->sub(snb);
-          })
-      .def(
-          "__mul__",
-          [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
-            auto snb = toSymFloatNode(a, b);
-            return a->mul(snb);
-          })
-      .def(
-          "__rmul__",
-          [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
-            auto snb = toSymFloatNode(a, b);
-            return snb->mul(a);
-          })
-      .def(
-          "__truediv__",
-          [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
-            auto snb = toSymFloatNode(a, b);
-            return a->truediv(snb);
-          })
-      .def(
-          "__rtruediv__",
-          [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
-            auto snb = toSymFloatNode(a, b);
-            return snb->truediv(a);
-          })
-      .def(
-          "__eq__",
-          [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
-            auto snb = toSymFloatNode(a, b);
-            return a->eq(snb);
-          })
-      .def(
-          "__gt__",
-          [](c10::SymFloatNode a, py::object b) {
-            auto snb = toSymFloatNode(a, b);
-            return a->gt(snb);
-          })
-      .def(
-          "__lt__",
-          [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
-            auto snb = toSymFloatNode(a, b);
-            return a->lt(snb);
-          })
-      .def(
-          "__le__",
-          [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
-            auto snb = toSymFloatNode(a, b);
-            return a->le(snb);
-          })
-      .def(
-          "__ge__",
-          [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
-            auto snb = toSymFloatNode(a, b);
-            return a->ge(snb);
-          })
-      .def(
-          "__pow__",
-          [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
-            auto snb = toSymFloatNode(a, b);
-            return a->pow(snb);
-          })
-      .def(
-          "__rpow__",
-          [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
-            auto snb = toSymFloatNode(a, b);
-            return snb->pow(a);
-          })
-      .def(
-          "__ceil__",
-          [](c10::SymFloatNode a) -> c10::SymIntNode { return a->ceil(); })
-      .def(
-          "__floor__",
-          [](c10::SymFloatNode a) -> c10::SymIntNode { return a->floor(); })
-      .def(
-          "get_pyobj",
-          [](c10::SymFloatNode a) -> py::object {
-            if (auto* psn = dynamic_cast<PythonSymFloatNodeImpl*>(a.get())) {
-              return py::reinterpret_borrow<py::object>(psn->getPyObj());
-            }
-            return py::none();
-          })
-      .def("__str__", [](c10::SymFloatNode a) { return a->str(); });
+  // NB: This isn't actually used for regular PyTorch symbolic tracing;
+  // XLA is what needs this
+#define SYMNODE_UNARY(n) .def(#n, [](c10::SymNode a) { return a->n(); })
+#define SYMNODE_UNARY2(n2, n) .def(#n2, [](c10::SymNode a) { return a->n(); })
+#define SYMNODE_BINARY(n) \
+  .def(#n, [](c10::SymNode a, c10::SymNode b) { return a->n(b); })
+  auto symnode_class =
+      py::class_<c10::SymNodeImpl, c10::SymNode>(m, "_SymNode")
+      // These DO NOT install magic methods; the SymInt/SymFloat wrapper in
+      // Python is responsible for this
+      SYMNODE_UNARY(clone)
+      // Named these for consistency with inner python class, but maybe
+      // should change the python side
+      SYMNODE_UNARY2(__bool__, bool_) SYMNODE_UNARY2(__int__, int_)
+      SYMNODE_UNARY2(__sym_int__, sym_int) SYMNODE_UNARY2(
+          __sym_float__, sym_float) SYMNODE_BINARY(add) SYMNODE_BINARY(sub)
+          SYMNODE_BINARY(mul) SYMNODE_BINARY(truediv) SYMNODE_BINARY(pow)
+          SYMNODE_BINARY(floordiv) SYMNODE_BINARY(mod) SYMNODE_BINARY(
+              eq) SYMNODE_BINARY(gt) SYMNODE_BINARY(lt)
+              SYMNODE_BINARY(le) SYMNODE_BINARY(ge) SYMNODE_BINARY(min)
+              SYMNODE_BINARY(max) SYMNODE_UNARY(ceil)
+              SYMNODE_UNARY(floor) SYMNODE_UNARY(neg)
+      // Intentionally don't set file line, as the
+      // Python backtrace matters more here
+      .def(
+          "guard_int",
+          [](c10::SymNode a) {
+            return a->guard_int(nullptr, 0);
+          })
+      .def(
+          "__str__",
+          [](c10::SymNode a) { return a->str(); })
+      .def("__repr__", [](c10::SymNode a) {
+        return a->str();
+      });

   // NOLINTNEXTLINE(bugprone-unused-raii)
   py::class_<CompleteArgumentSpec>(m, "CompleteArgumentSpec")
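The binding above gives _SymNode only plain-named methods; magic methods live on the Python-side SymInt/SymFloat wrappers. A hedged sketch of the resulting division of labor (the plain .add method name is taken from the macros above; the rest is illustrative):

import torch

def add_symints(a: torch.SymInt, b: torch.SymInt) -> torch.SymInt:
    # SymInt.__add__ is installed in Python and ends up doing roughly this:
    # call the plain .add on the wrapped nodes, then re-wrap the result.
    return torch.SymInt(a.node.add(b.node))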
@@ -80,10 +80,10 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
     scalar = at::Scalar(THPUtils_unpackComplexDouble(obj.ptr()));
   } else if (THPUtils_checkDouble(obj.ptr())) {
     scalar = at::Scalar(THPUtils_unpackDouble(obj.ptr()));
-  } else if (torch::is_symint_node(py::handle(obj))) {
+  } else if (torch::is_symint(py::handle(obj))) {
     save_symint = true;
     scalar = at::Scalar(7777777);
-  } else if (torch::is_symfloat_node(py::handle(obj))) {
+  } else if (torch::is_symfloat(py::handle(obj))) {
     save_symint = true;
     scalar = at::Scalar(std::numeric_limits<double>::quiet_NaN());
   } else {
@ -161,12 +161,12 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
|
|||||||
return py::cast<int64_t>(obj);
|
return py::cast<int64_t>(obj);
|
||||||
}
|
}
|
||||||
case TypeKind::SymIntType:
|
case TypeKind::SymIntType:
|
||||||
if (torch::is_symint_node(obj.ptr())) {
|
if (torch::is_symint(obj.ptr())) {
|
||||||
return py::cast<c10::SymInt>(obj);
|
return py::cast<c10::SymInt>(obj);
|
||||||
}
|
}
|
||||||
return py::cast<int64_t>(obj);
|
return py::cast<int64_t>(obj);
|
||||||
case TypeKind::SymFloatType:
|
case TypeKind::SymFloatType:
|
||||||
if (torch::is_symfloat_node(obj.ptr())) {
|
if (torch::is_symfloat(obj.ptr())) {
|
||||||
return py::cast<c10::SymFloat>(obj);
|
return py::cast<c10::SymFloat>(obj);
|
||||||
}
|
}
|
||||||
return py::cast<double>(obj);
|
return py::cast<double>(obj);
|
||||||
@ -253,7 +253,7 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
|
|||||||
bool is_symbolic = false;
|
bool is_symbolic = false;
|
||||||
for (auto it = obj.begin(); it != obj.end(); it++) {
|
for (auto it = obj.begin(); it != obj.end(); it++) {
|
||||||
auto elm = *it;
|
auto elm = *it;
|
||||||
if (torch::is_symint_node(elm)) {
|
if (torch::is_symint(elm)) {
|
||||||
is_symbolic = true;
|
is_symbolic = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -269,7 +269,7 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
|
|||||||
for (auto it = obj.begin(); it != obj.end(); it++) {
|
for (auto it = obj.begin(); it != obj.end(); it++) {
|
||||||
auto elm = *it;
|
auto elm = *it;
|
||||||
// TODO: what about SymInt conversion to SymFloat?
|
// TODO: what about SymInt conversion to SymFloat?
|
||||||
if (torch::is_symfloat_node(elm)) {
|
if (torch::is_symfloat(elm)) {
|
||||||
is_symbolic = true;
|
is_symbolic = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -442,9 +442,9 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
|
|||||||
} else if (PyComplex_CheckExact(obj.ptr())) {
|
} else if (PyComplex_CheckExact(obj.ptr())) {
|
||||||
auto c_obj = py::cast<std::complex<double>>(obj.ptr());
|
auto c_obj = py::cast<std::complex<double>>(obj.ptr());
|
||||||
return static_cast<c10::complex<double>>(c_obj);
|
return static_cast<c10::complex<double>>(c_obj);
|
||||||
} else if (torch::is_symint_node(obj)) {
|
} else if (torch::is_symint(obj)) {
|
||||||
return py::cast<c10::SymInt>(obj);
|
return py::cast<c10::SymInt>(obj);
|
||||||
} else if (torch::is_symfloat_node(obj)) {
|
} else if (torch::is_symfloat(obj)) {
|
||||||
return py::cast<c10::SymFloat>(obj);
|
return py::cast<c10::SymFloat>(obj);
|
||||||
} else {
|
} else {
|
||||||
throw py::cast_error(
|
throw py::cast_error(
|
||||||
|
@@ -136,10 +136,10 @@ static inline NodePtr MakeSizeDiv(const Value& a, const Value& b) {

 inline Value GetSymIntValue(c10::SymInt a) {
   return Value(
-      a.is_symbolic() ? dynamic_cast<torch::lazy::SymIntNodeImpl*>(
-                            a.toSymIntNodeImpl().get())
+      a.is_symbolic()
+          ? dynamic_cast<torch::lazy::SymNodeImpl*>(a.toSymNodeImpl().get())
                ->node_
          : MakeScalar(a.as_int_unchecked(), at::kLong),
      0);
 }

@@ -451,11 +451,11 @@ std::vector<Shape> compute_shape_expand(
   std::vector<int64_t> target_size(_sizes.size());
   for (const auto idx : c10::irange(_sizes.size())) {
     if (_sizes[idx].is_symbolic()) {
-      c10::SymIntNode symbolicIntNode = _sizes[idx].toSymIntNodeImpl();
-      auto* lazySymIntNode =
-          dynamic_cast<torch::lazy::SymIntNodeImpl*>(symbolicIntNode.get());
-      TORCH_INTERNAL_ASSERT(lazySymIntNode);
-      auto size_node = lazySymIntNode->node_;
+      c10::SymNode symbolicIntNode = _sizes[idx].toSymNodeImpl();
+      auto* lazySymNode =
+          dynamic_cast<torch::lazy::SymNodeImpl*>(symbolicIntNode.get());
+      TORCH_INTERNAL_ASSERT(lazySymNode);
+      auto size_node = lazySymNode->node_;
       auto static_value =
           std::dynamic_pointer_cast<torch::lazy::DimensionNode>(size_node)
               ->getStaticValue();
@@ -4,7 +4,7 @@
 #include <c10/core/ScalarType.h>
 #include <c10/core/SymInt.h>
 #include <c10/core/SymIntArrayRef.h>
-#include <c10/core/SymIntNodeImpl.h>
+#include <c10/core/SymNodeImpl.h>
 #include <c10/macros/Export.h>
 #include <c10/util/Optional.h>
 #include <torch/csrc/lazy/backend/backend_data.h>
@@ -1,6 +1,6 @@
 #pragma once

-#include <c10/core/SymIntNodeImpl.h>
+#include <c10/core/SymNodeImpl.h>
 #include <c10/util/intrusive_ptr.h>
 #include <torch/csrc/lazy/backend/backend_data.h>
 #include <torch/csrc/lazy/backend/backend_device.h>
@@ -10,12 +10,9 @@
 namespace torch {
 namespace lazy {

-class TORCH_API SymIntNodeImpl : public c10::SymIntNodeImpl {
+class TORCH_API SymNodeImpl : public c10::SymNodeImpl {
  public:
-  SymIntNodeImpl(NodePtr ptr) : node_(std::move(ptr)){};
-  c10::SymIntNode add(const c10::SymIntNode& other) override {
-    TORCH_CHECK(false, "NYI");
-  }
+  SymNodeImpl(NodePtr ptr) : node_(std::move(ptr)){};
   NodePtr node_;
 };

@@ -685,7 +685,7 @@ static bool is_int_list(
   // NB: do NOT check that the later arguments are ints, as this is
   // BC-breaking for FX
   for (int i = 1; i < len; i++) {
-    if (torch::is_symint_node(
+    if (torch::is_symint(
            py::reinterpret_steal<py::object>(PySequence_GetItem(obj, i)))) {
      if (failed_idx != nullptr) {
        *failed_idx = i;
@@ -716,9 +716,9 @@ static bool is_int_list(
 static bool is_int_or_symint(PyObject* obj) {
   // THPUtils_checkIndex may call __index__ or __int__
   // which may have side effects if obj is a symint node
-  // so we do `is_symint_node` check first
+  // so we do `is_symint` check first
   // TODO: maybe we should be using checkLong here?
-  return torch::is_symint_node(py::handle(obj)) || THPUtils_checkIndex(obj);
+  return torch::is_symint(py::handle(obj)) || THPUtils_checkIndex(obj);
 }

 static bool is_int_or_symint_list(
|
|||||||
// NB: we DO NOT put symbolic ints/floats into the Scalar itself,
|
// NB: we DO NOT put symbolic ints/floats into the Scalar itself,
|
||||||
// because although Scalar supports SymInt/SymFloat, the subsequent
|
// because although Scalar supports SymInt/SymFloat, the subsequent
|
||||||
// conversion to Tensor does not. Instead, do it out of band.
|
// conversion to Tensor does not. Instead, do it out of band.
|
||||||
} else if (torch::is_symint_node(py::handle(obj))) {
|
} else if (torch::is_symint(py::handle(obj))) {
|
||||||
save_symint = true;
|
save_symint = true;
|
||||||
// This scalar value doesn't matter, it shouldn't ever actually
|
// This scalar value doesn't matter, it shouldn't ever actually
|
||||||
// get read out. Make it a big and weird looking number to help
|
// get read out. Make it a big and weird looking number to help
|
||||||
// people figure out if there's aproblem.
|
// people figure out if there's aproblem.
|
||||||
scalar = at::Scalar(7777777);
|
scalar = at::Scalar(7777777);
|
||||||
} else if (torch::is_symfloat_node(py::handle(obj))) {
|
} else if (torch::is_symfloat(py::handle(obj))) {
|
||||||
save_symint = true;
|
save_symint = true;
|
||||||
scalar = at::Scalar(std::numeric_limits<double>::quiet_NaN());
|
scalar = at::Scalar(std::numeric_limits<double>::quiet_NaN());
|
||||||
} else {
|
} else {
|
||||||
@ -1633,11 +1633,11 @@ at::Scalar PythonArgs::scalar_slow(PyObject* arg) {
|
|||||||
return at::Scalar(THPUtils_unpackComplexDouble(arg));
|
return at::Scalar(THPUtils_unpackComplexDouble(arg));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (torch::is_symint_node(arg)) {
|
if (torch::is_symint(arg)) {
|
||||||
return at::Scalar(py::cast<c10::SymInt>(arg));
|
return at::Scalar(py::cast<c10::SymInt>(arg));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (torch::is_symfloat_node(arg)) {
|
if (torch::is_symfloat(arg)) {
|
||||||
return at::Scalar(py::cast<c10::SymFloat>(arg));
|
return at::Scalar(py::cast<c10::SymFloat>(arg));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -61,6 +61,7 @@
 #include <torch/csrc/utils/pybind.h>
 #include <torch/csrc/utils/python_numbers.h>
 #include <torch/csrc/utils/python_strings.h>
+#include <torch/csrc/utils/python_symnode.h>
 #include <torch/csrc/utils/six.h>

 #include <ATen/PythonTorchFunctionTLS.h>
@@ -69,7 +70,7 @@
 #include <c10/util/irange.h>

 #include <c10/core/SymFloat.h>
-#include <c10/core/SymIntNodeImpl.h>
+#include <c10/core/SymNodeImpl.h>

 #include <array>
 #include <cstddef>
@@ -78,30 +79,6 @@
 #include <string>
 #include <vector>

-namespace torch {
-
-inline bool is_symint_node(py::handle obj) {
-  auto static tp_symn = py::type::of<c10::SymIntNodeImpl>();
-  if (py::isinstance(obj, tp_symn)) {
-    TORCH_CHECK(
-        !jit::tracer::isTracing(), "JIT tracing of SymInts isn't supported!");
-    return true;
-  }
-  return false;
-}
-
-inline bool is_symfloat_node(py::handle obj) {
-  auto static tp_symn = py::type::of<c10::SymFloatNodeImpl>();
-  if (py::isinstance(obj, tp_symn)) {
-    TORCH_CHECK(
-        !jit::tracer::isTracing(), "JIT tracing of SymFloats isn't supported!");
-    return true;
-  }
-  return false;
-}
-
-} // namespace torch
-
 namespace pybind11 {
 namespace detail {
 template <>
@@ -109,8 +86,10 @@ struct type_caster<c10::SymInt> {
  public:
   PYBIND11_TYPE_CASTER(c10::SymInt, _("SymInt"));
   bool load(py::handle src, bool) {
-    if (torch::is_symint_node(src)) {
-      value = src.cast<c10::SymIntNodeImpl*>()->toSymInt();
+    if (torch::is_symint(src)) {
+      value = c10::SymInt(static_cast<c10::SymNode>(
+          c10::make_intrusive<torch::impl::PythonSymNodeImpl>(
+              src.attr("node"))));
       return true;
     }

@@ -126,8 +105,15 @@ struct type_caster<c10::SymInt> {
       c10::SymInt si,
       return_value_policy /* policy */,
       handle /* parent */) {
-    return si.is_symbolic() ? py::cast(si.toSymIntNodeImpl()).release()
-                            : py::cast(si.expect_int()).release();
+    if (si.is_symbolic()) {
+      // TODO: generalize this to work with C++ backed class
+      auto* py_node = dynamic_cast<torch::impl::PythonSymNodeImpl*>(
+          si.toSymNodeImpl().get());
+      TORCH_INTERNAL_ASSERT(py_node);
+      return torch::get_symint_class()(py_node->getPyObj()).release();
+    } else {
+      return py::cast(si.as_int_unchecked()).release();
+    }
   }
 };

@@ -136,8 +122,10 @@ struct type_caster<c10::SymFloat> {
  public:
   PYBIND11_TYPE_CASTER(c10::SymFloat, _("SymFloat"));
   bool load(py::handle src, bool) {
-    if (torch::is_symfloat_node(src)) {
-      value = src.cast<c10::SymFloatNodeImpl*>()->toSymFloat();
+    if (torch::is_symfloat(src)) {
+      value = c10::SymFloat(static_cast<c10::SymNode>(
+          c10::make_intrusive<torch::impl::PythonSymNodeImpl>(
+              src.attr("node"))));
       return true;
     }

@@ -153,8 +141,15 @@ struct type_caster<c10::SymFloat> {
       c10::SymFloat si,
       return_value_policy /* policy */,
       handle /* parent */) {
-    return si.is_symbolic() ? py::cast(si.toSymFloatNodeImpl()).release()
-                            : py::cast(si.expect_float()).release();
+    if (si.is_symbolic()) {
+      // TODO: generalize this to work with C++ backed class
+      auto* py_node = dynamic_cast<torch::impl::PythonSymNodeImpl*>(
+          si.toSymNodeImpl().get());
+      TORCH_INTERNAL_ASSERT(py_node);
+      return torch::get_symfloat_class()(py_node->getPyObj()).release();
+    } else {
+      return py::cast(si.as_float_unchecked()).release();
+    }
   }
 };
 } // namespace detail
@@ -167,8 +162,7 @@ inline bool THPUtils_checkScalar(PyObject* obj) {
 }
 #endif
   return PyFloat_Check(obj) || PyLong_Check(obj) || PyComplex_Check(obj) ||
-      torch::is_symint_node(py::handle(obj)) ||
-      torch::is_symfloat_node(py::handle(obj));
+      torch::is_symint(py::handle(obj)) || torch::is_symfloat(py::handle(obj));
 }

 namespace torch {
|
|||||||
|
|
||||||
inline PyObject* toPyObject(c10::SymInt symint) {
|
inline PyObject* toPyObject(c10::SymInt symint) {
|
||||||
if (symint.is_symbolic()) {
|
if (symint.is_symbolic()) {
|
||||||
auto r = py::cast(symint.toSymIntNodeImpl()).release().ptr();
|
auto r = py::cast(symint).release().ptr();
|
||||||
TORCH_INTERNAL_ASSERT(r);
|
TORCH_INTERNAL_ASSERT(r);
|
||||||
return r;
|
return r;
|
||||||
} else {
|
} else {
|
||||||
@ -609,8 +603,8 @@ inline std::vector<c10::SymInt> PythonArgs::symintlist(int i) {
|
|||||||
size1, c10::SymInt(THPUtils_unpackIndex(args[i])));
|
size1, c10::SymInt(THPUtils_unpackIndex(args[i])));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (size1 > 0 && torch::is_symint_node(py::handle(args[i]))) {
|
if (size1 > 0 && torch::is_symint(py::handle(args[i]))) {
|
||||||
auto si = py::handle(args[i]).cast<c10::SymIntNodeImpl*>()->toSymInt();
|
auto si = py::handle(args[i]).cast<c10::SymInt>();
|
||||||
return std::vector<c10::SymInt>(size1, si);
|
return std::vector<c10::SymInt>(size1, si);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -652,9 +646,8 @@ inline std::vector<c10::SymInt> PythonArgs::symintlist(int i) {
|
|||||||
res.push_back(var.item<int64_t>());
|
res.push_back(var.item<int64_t>());
|
||||||
} else {
|
} else {
|
||||||
try {
|
try {
|
||||||
if (is_symint_node(py::handle(obj))) {
|
if (is_symint(py::handle(obj))) {
|
||||||
res.push_back(
|
res.push_back(py::handle(obj).cast<c10::SymInt>());
|
||||||
py::handle(obj).cast<c10::SymIntNodeImpl*>()->toSymInt());
|
|
||||||
} else {
|
} else {
|
||||||
res.push_back(c10::SymInt(THPUtils_unpackIndex(obj)));
|
res.push_back(c10::SymInt(THPUtils_unpackIndex(obj)));
|
||||||
}
|
}
|
||||||
|
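Taken together, the caster changes above reduce the C++/Python boundary to a simple two-way contract: load() unwraps the user object by reading its node attribute, and cast() rewraps a node by calling the Python SymInt/SymFloat class with it. A minimal Python sketch of that contract; MockSymInt and MockNode are hypothetical stand-ins for illustration, not torch APIs:

# Sketch of the unwrap/rewrap contract the new pybind11 casters assume.
class MockNode:
    def __init__(self, expr):
        self.expr = expr

class MockSymInt:
    def __init__(self, node):
        self.node = node  # load() reads this via src.attr("node")

def load_to_cpp(obj):
    # mirrors type_caster::load: adopt obj.node into a PythonSymNodeImpl
    return obj.node

def cast_to_python(node):
    # mirrors type_caster::cast: get_symint_class()(py_node->getPyObj())
    return MockSymInt(node)

s = MockSymInt(MockNode("s0 + 1"))
roundtripped = cast_to_python(load_to_cpp(s))
assert roundtripped.node.expr == "s0 + 1"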
19  torch/csrc/utils/python_symnode.cpp  Normal file
@@ -0,0 +1,19 @@
+#include <torch/csrc/utils/python_symnode.h>
+
+namespace torch {
+
+py::handle get_symint_class() {
+  // NB: leak
+  static py::handle symint_class =
+      py::object(py::module::import("torch").attr("SymInt")).release();
+  return symint_class;
+}
+
+py::handle get_symfloat_class() {
+  // NB: leak
+  static py::handle symfloat_class =
+      py::object(py::module::import("torch").attr("SymFloat")).release();
+  return symfloat_class;
+}
+
+} // namespace torch
182  torch/csrc/utils/python_symnode.h  Normal file
@@ -0,0 +1,182 @@
+#pragma once
+
+#include <c10/core/SafePyObject.h>
+#include <c10/core/SymNodeImpl.h>
+
+#include <torch/csrc/autograd/python_variable.h>
+#include <torch/csrc/utils/pybind.h>
+
+namespace torch {
+
+TORCH_PYTHON_API py::handle get_symint_class();
+TORCH_PYTHON_API py::handle get_symfloat_class();
+
+// NB: These functions must not be called too early, otherwise torch not setup.
+// Alternate design is to have torch "register" the object to us
+inline bool is_symint(py::handle obj) {
+  return py::isinstance(obj, get_symint_class());
+}
+inline bool is_symfloat(py::handle obj) {
+  return py::isinstance(obj, get_symfloat_class());
+}
+
+namespace impl {
+
+// This c10::SymNodeImpl simply backends to a Python object that
+// implements the API.  The Python object is the source of truth,
+// this is just an adapter so C++ calls can get to the object.
+class PythonSymNodeImpl : public c10::SymNodeImpl {
+ public:
+  PythonSymNodeImpl(py::object pyobj) : c10::SymNodeImpl() {
+    pyobj_ = std::make_shared<c10::SafePyObject>(
+        pyobj.release().ptr(), getPyInterpreter());
+  };
+
+  c10::SymNode wrap_int(int64_t num) override {
+    py::gil_scoped_acquire acquire;
+    auto r = getPyObj().attr("wrap_int")(num);
+    return c10::make_intrusive<PythonSymNodeImpl>(r);
+  }
+
+  c10::SymNode wrap_float(double num) override {
+    py::gil_scoped_acquire acquire;
+    auto r = getPyObj().attr("wrap_float")(num);
+    return c10::make_intrusive<PythonSymNodeImpl>(r);
+  }
+
+  bool bool_() override {
+    py::gil_scoped_acquire acquire;
+    return getPyObj().attr("bool_")().is(py::handle(Py_True));
+  }
+
+  bool is_int() override {
+    py::gil_scoped_acquire acquire;
+    return getPyObj().attr("is_int")().is(py::handle(Py_True));
+  }
+
+  bool is_float() override {
+    py::gil_scoped_acquire acquire;
+    return getPyObj().attr("is_float")().is(py::handle(Py_True));
+  }
+
+  int64_t guard_int(const char* file, int64_t line) override {
+    py::gil_scoped_acquire acquire;
+    return getPyObj().attr("guard_int")(file, line).cast<int64_t>();
+  }
+
+  double guard_float(const char* file, int64_t line) override {
+    py::gil_scoped_acquire acquire;
+    return getPyObj().attr("guard_float")(file, line).cast<double>();
+  }
+
+  int64_t int_() override {
+    py::gil_scoped_acquire acquire;
+    return getPyObj().attr("int_")().cast<int64_t>();
+  }
+
+  std::string str() override {
+    py::gil_scoped_acquire acquire;
+    return getPyObj().attr("str")().cast<std::string>();
+  }
+
+  c10::SymNode dispatch_common_(const char* fname, const c10::SymNode& other) {
+    auto pother = dynamic_cast<PythonSymNodeImpl*>(other.get());
+    TORCH_CHECK(pother);
+    py::gil_scoped_acquire acquire;
+    auto r = getPyObj().attr(fname)(pother->getPyObj());
+    return c10::make_intrusive<PythonSymNodeImpl>(r);
+  }
+
+  c10::SymNode dispatch_common_(const char* fname) {
+    py::gil_scoped_acquire acquire;
+    auto r = getPyObj().attr(fname)();
+    return c10::make_intrusive<PythonSymNodeImpl>(r);
+  }
+
+  c10::SymNode add(const c10::SymNode& other) override {
+    return dispatch_common_(__FUNCTION__, other);
+  }
+
+  c10::SymNode sub(const c10::SymNode& other) override {
+    return dispatch_common_(__FUNCTION__, other);
+  }
+
+  c10::SymNode mul(const c10::SymNode& other) override {
+    return dispatch_common_(__FUNCTION__, other);
+  }
+
+  c10::SymNode truediv(const c10::SymNode& other) override {
+    return dispatch_common_(__FUNCTION__, other);
+  }
+
+  c10::SymNode pow(const c10::SymNode& other) override {
+    return dispatch_common_(__FUNCTION__, other);
+  }
+
+  c10::SymNode floordiv(const c10::SymNode& other) override {
+    return dispatch_common_(__FUNCTION__, other);
+  }
+
+  c10::SymNode mod(const c10::SymNode& other) override {
+    return dispatch_common_(__FUNCTION__, other);
+  }
+
+  c10::SymNode eq(const c10::SymNode& other) override {
+    return dispatch_common_(__FUNCTION__, other);
+  }
+
+  c10::SymNode gt(const c10::SymNode& other) override {
+    return dispatch_common_(__FUNCTION__, other);
+  }
+
+  c10::SymNode lt(const c10::SymNode& other) override {
+    return dispatch_common_(__FUNCTION__, other);
+  }
+
+  c10::SymNode le(const c10::SymNode& other) override {
+    return dispatch_common_(__FUNCTION__, other);
+  }
+
+  c10::SymNode ge(const c10::SymNode& other) override {
+    return dispatch_common_(__FUNCTION__, other);
+  }
+
+  c10::SymNode min(const c10::SymNode& other) override {
+    return dispatch_common_(__FUNCTION__, other);
+  }
+  c10::SymNode max(const c10::SymNode& other) override {
+    return dispatch_common_(__FUNCTION__, other);
+  }
+
+  c10::SymNode ceil() override {
+    return dispatch_common_(__FUNCTION__);
+  }
+
+  c10::SymNode floor() override {
+    return dispatch_common_(__FUNCTION__);
+  }
+
+  c10::SymNode neg() override {
+    return dispatch_common_(__FUNCTION__);
+  }
+
+  c10::SymNode clone() override {
+    return dispatch_common_(__FUNCTION__);
+  }
+
+  c10::SymNode sym_int() override {
+    return dispatch_common_(__FUNCTION__);
+  }
+
+  c10::SymNode sym_float() override {
+    return dispatch_common_(__FUNCTION__);
+  }
+
+  py::handle getPyObj() {
+    return py::handle(pyobj_.get()->ptr(getPyInterpreter()));
+  }
+  std::shared_ptr<c10::SafePyObject> pyobj_ = nullptr;
+};
+
+} // namespace impl
+} // namespace torch
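Note that the adapter above never inspects the Python object's type; it only calls a fixed set of method names (wrap_int, is_int, guard_int, bool_, add, ...) under the GIL. A toy, constant-backed object satisfying a slice of that duck-typed protocol might look as follows; this is an illustrative sketch, not the real SymNode defined later in symbolic_shapes.py:

# Toy node implementing part of the protocol PythonSymNodeImpl dispatches to.
class ToyNode:
    def __init__(self, val):
        self.val = val

    def wrap_int(self, num):        # called by the C++ wrap_int(int64_t)
        return ToyNode(int(num))

    def wrap_float(self, num):      # called by the C++ wrap_float(double)
        return ToyNode(float(num))

    def is_int(self):               # C++ is_int() compares the result to Py_True
        return isinstance(self.val, int)

    def is_float(self):
        return isinstance(self.val, float)

    def bool_(self):
        return bool(self.val)

    def guard_int(self, file, line):  # the real node records a guard here
        return int(self.val)

    def str(self):
        return repr(self.val)

    def add(self, other):           # binary ops take and return nodes
        return ToyNode(self.val + other.val)

n = ToyNode(3)
assert n.is_int() and not n.is_float()
assert n.add(n.wrap_int(4)).guard_int("<sketch>", 0) == 7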
@@ -21,8 +21,9 @@ import operator
 from torch.utils._python_dispatch import TorchDispatchMode, _pop_mode_temporarily, _get_current_dispatch_mode
 from torch._subclasses import FakeTensor
-from .symbolic_shapes import ShapeEnv, SymDispatchMode, PySymInt, PySymFloat
+from .symbolic_shapes import ShapeEnv, SymDispatchMode, SymNode
 from torch.fx import Proxy
+from torch import SymInt, SymFloat

 __all__ = ["PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter", "get_proxy", "has_proxy", "py_sym_types"]
 aten = torch.ops.aten
@@ -55,27 +56,27 @@ def decompose(decomposition_table):
 proxy_slot = object()
 no_default = object()

-py_sym_types = (
-    PySymInt,
-    PySymFloat,
-)
+py_sym_types = (SymInt, SymFloat)
 def is_sym_node(node):
     assert hasattr(node, 'meta'), "All nodes traced with proxy_tensor should have meta"
     return "val" in node.meta and isinstance(node.meta['val'], py_sym_types)

 def set_proxy_slot(obj, tracer, proxy):
-    d = obj.__dict__.setdefault(proxy_slot, weakref.WeakKeyDictionary())
+    assert isinstance(obj, (torch.Tensor, SymNode)), type(obj)
+    d = obj.__dict__.setdefault(proxy_slot, weakref.WeakKeyDictionary())  # type: ignore[call-overload]
     assert isinstance(d, weakref.WeakKeyDictionary)
     d[tracer] = proxy

 def has_proxy_slot(obj, tracer):
+    assert isinstance(obj, (torch.Tensor, SymNode)), type(obj)
     return get_proxy_slot(obj, tracer, False, lambda _: True)

 # the default argument is what to return if the slot is not set.
 # the transform argument is handy if you need to extract a subfield from
 # the successfully looked up result (but NOT the default.)
 def get_proxy_slot(obj, tracer, default=no_default, transform=lambda x: x):
-    d = obj.__dict__.get(proxy_slot)
+    assert isinstance(obj, (torch.Tensor, SymNode)), type(obj)
+    d = obj.__dict__.get(proxy_slot)  # type: ignore[call-overload]
     if not d:
         if default is no_default:
             raise KeyError(f"{obj} is not tracked with proxy for {tracer}")
@@ -130,10 +131,8 @@ def track_tensor(tensor, proxy, *, constant, tracer):
     def try_set_proxy_slot(outer_s, proxy_callable, *args):
         assert callable(proxy_callable)
         if isinstance(outer_s, SymInt):
-            inner_s = outer_s.get_pyobj()
-            assert isinstance(inner_s, py_sym_types)
-
-            set_proxy_slot(inner_s, tracer, thunkify(proxy_callable, inner_s, *args))
+            inner_s = outer_s.node
+            set_proxy_slot(inner_s, tracer, thunkify(proxy_callable, outer_s, *args))

     # The basic idea is that we need to associate each tensor/SymInt
     # with a Proxy.  How do we setup this association?  We just store
@@ -198,7 +197,7 @@ class _ProxyTensor:

 def fetch_sym_proxy(tracer):
     def inner(e):
-        n = e.get_pyobj()
+        n = e.node
         if n.constant is not None:
             return n.constant
         else:
@@ -400,8 +399,8 @@ class PythonKeyTracer(Tracer):

             return self.create_node('get_attr', qualname, (), {})
         elif isinstance(a, (SymInt, SymFloat)):
-            assert a.get_pyobj().constant is not None
-            return a.get_pyobj().constant
+            assert a.node.constant is not None
+            return a.node.constant
         return super().create_arg(a)

@@ -432,7 +431,7 @@ def wrap_key(f, tensors, tracer):
     )
     out = pytree.tree_map_only(
         (SymInt, SymFloat),
-        lambda t: get_proxy_slot(t.get_pyobj(), tracer)(),
+        lambda t: get_proxy_slot(t.node, tracer)(),
         out
     )
     return out
@@ -479,10 +478,6 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
         return out


-SymInt = torch.SymIntNode
-SymFloat = torch.SymFloatNode
-
-
 class ProxySymDispatchMode(SymDispatchMode):
     def __init__(self, tracer):
         super().__init__()
@@ -501,10 +496,9 @@ class ProxySymDispatchMode(SymDispatchMode):
         finally:
             self.enable_tracing = old

-    def _compute_proxy(self, func, args, out):
+    def _compute_proxy(self, func, args, out: Union[SymInt, SymFloat]):
         n_args = tuple(
-            get_proxy_slot(a, self.tracer)().node if a.constant is None else a.constant
-            if isinstance(a, py_sym_types) else a
+            get_proxy_slot(a.node, self.tracer)().node if isinstance(a, py_sym_types) else a
             for a in args
         )

@@ -520,10 +514,11 @@ class ProxySymDispatchMode(SymDispatchMode):
             return func(*args, **kwargs)

         # Peephole optimize multiply by one
+        # NB: be careful not to trigger guards here!
         if func == operator.mul:
-            if isinstance(args[1], (PySymInt, PySymFloat)) and args[1].constant == 1:
+            if isinstance(args[1], int) and args[1] == 1:
                 return args[0]
-            elif isinstance(args[0], (PySymInt, PySymFloat)) and args[0].constant == 1:
+            elif isinstance(args[0], int) and args[0] == 1:
                 return args[1]

         # For speed, we assume there are no nested data structures
@@ -535,7 +530,7 @@ class ProxySymDispatchMode(SymDispatchMode):

         # Delays tracing out the proxies on this op until we actually need it
         p_out_thunk = thunkify(self._compute_proxy, func=func, args=args, out=out)
-        set_proxy_slot(out, self.tracer, p_out_thunk)
+        set_proxy_slot(out.node, self.tracer, p_out_thunk)
         return out

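One pattern runs through all the tracing hunks above: every cache and lookup now unwraps the user-level SymInt/SymFloat to its .node first (set_proxy_slot(out.node, ...), get_proxy_slot(t.node, ...)). A plausible reading is that wrapper objects may be recreated freely while the node is the stable identity, so proxy slots must key on the node. A sketch of the lookup shape, with stand-in classes rather than the real tracer:

# Sketch of the fetch path after this change: unwrap to the node, fold
# constants, otherwise look the node up. NodeStub/SymIntStub are stand-ins.
class NodeStub:
    def __init__(self, constant=None):
        self.constant = constant

class SymIntStub:
    def __init__(self, node):
        self.node = node

proxies = {}  # the real code uses a per-object WeakKeyDictionary slot

def fetch(e):
    n = e.node                   # was: e.get_pyobj()
    if n.constant is not None:   # constant-backed symbols fold to plain values
        return n.constant
    return proxies[n]            # proxies are keyed on the node itself

node = NodeStub()
proxies[node] = "proxy-for-s0"
assert fetch(SymIntStub(NodeStub(constant=1))) == 1
assert fetch(SymIntStub(node)) == "proxy-for-s0"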
@@ -10,6 +10,7 @@ import traceback
 import collections
 import textwrap
 from torch._subclasses.meta_utils import MetaConverter
+from torch import SymInt, SymFloat

 try:
     import sympy  # type: ignore[import]
|
|||||||
aten = torch.ops.aten # type: ignore[has-type]
|
aten = torch.ops.aten # type: ignore[has-type]
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"has_symbolic_sizes_strides", "create_contiguous", "PySymInt", "ShapeEnv",
|
"has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv",
|
||||||
"SymDispatchMode", "PySymFloat", "sym_float", "FloorDiv"
|
"SymDispatchMode", "sym_float", "FloorDiv", "guard_int", "wrap_node"
|
||||||
]
|
]
|
||||||
|
|
||||||
SYM_FUNCTION_MODE = None
|
SYM_FUNCTION_MODE = None
|
||||||
@ -88,32 +89,38 @@ def _handle_sym_dispatch(func, args, kwargs):
|
|||||||
finally:
|
finally:
|
||||||
SYM_FUNCTION_MODE = mode
|
SYM_FUNCTION_MODE = mode
|
||||||
|
|
||||||
|
def guard_int(a):
|
||||||
|
if isinstance(a, SymInt):
|
||||||
|
return a.node.guard_int("", 0) # NB: uses Python backtrace
|
||||||
|
assert isinstance(a, int)
|
||||||
|
return a
|
||||||
|
|
||||||
def sym_float(a):
|
def sym_float(a):
|
||||||
if hasattr(a, '__sym_float__'):
|
if isinstance(a, SymFloat):
|
||||||
return a.__sym_float__()
|
|
||||||
elif isinstance(a, torch._C.SymFloatNode):
|
|
||||||
return a
|
return a
|
||||||
|
elif hasattr(a, '__sym_float__'):
|
||||||
|
return a.__sym_float__()
|
||||||
return float(a)
|
return float(a)
|
||||||
|
|
||||||
def sym_int(a):
|
def sym_int(a):
|
||||||
if hasattr(a, '__sym_int__'):
|
if isinstance(a, SymInt):
|
||||||
return a.__sym_int__()
|
|
||||||
elif isinstance(a, torch._C.SymIntNode):
|
|
||||||
return a
|
return a
|
||||||
|
elif hasattr(a, '__sym_int__'):
|
||||||
|
return a.__sym_int__()
|
||||||
return int(a)
|
return int(a)
|
||||||
|
|
||||||
# TODO: An incomplete list
|
# TODO: An incomplete list
|
||||||
# 1. Set variables to be equal when we do equality
|
# 1. Set variables to be equal when we do equality
|
||||||
# 2. Specialize on 0/1 when we do subtraction
|
# 2. Specialize on 0/1 when we do subtraction
|
||||||
class PySymInt(object):
|
class SymNode:
|
||||||
"""
|
"""
|
||||||
PySymInt objects are the primary "symbolic shape" objects that flow through
|
This is a type erased SymInt/SymFloat which we use to do actual operations.
|
||||||
our program. They're what sit under FakeTensor, and contains our primary
|
End users don't touch this. Magic methods are NOT defined on this object.
|
||||||
implementation of symbolic shapes.
|
|
||||||
"""
|
"""
|
||||||
def __init__(self, expr, shape_env, constant=None):
|
def __init__(self, expr, shape_env, pytype, constant=None):
|
||||||
self._expr = expr
|
self._expr = expr
|
||||||
self.shape_env = shape_env
|
self.shape_env = shape_env
|
||||||
|
self.pytype = pytype
|
||||||
self.constant = constant
|
self.constant = constant
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -121,23 +128,49 @@ class PySymInt(object):
|
|||||||
self._update_expr()
|
self._update_expr()
|
||||||
return self._expr
|
return self._expr
|
||||||
|
|
||||||
def wrap(self, num):
|
|
||||||
return PySymInt(sympy.Integer(num), self.shape_env, constant=num)
|
|
||||||
|
|
||||||
def clone(self):
|
|
||||||
return PySymInt(self.expr, self.shape_env, constant=self.constant)
|
|
||||||
|
|
||||||
def _update_expr(self):
|
def _update_expr(self):
|
||||||
self._expr = self.shape_env.replace(self._expr)
|
self._expr = self.shape_env.replace(self._expr)
|
||||||
|
|
||||||
def __str__(self):
|
def to_node(self, num):
|
||||||
|
if isinstance(num, (SymInt, SymFloat)):
|
||||||
|
return num.node
|
||||||
|
elif isinstance(num, int):
|
||||||
|
return self.wrap_int(num)
|
||||||
|
elif isinstance(num, float):
|
||||||
|
return self.wrap_float(num)
|
||||||
|
else:
|
||||||
|
# NotImplementedError is important so that Python tries the
|
||||||
|
# other magic method
|
||||||
|
raise NotImplementedError(type(num))
|
||||||
|
|
||||||
|
def is_int(self):
|
||||||
|
return self.pytype is int
|
||||||
|
|
||||||
|
def is_float(self):
|
||||||
|
return self.pytype is float
|
||||||
|
|
||||||
|
def wrap_int(self, num):
|
||||||
|
assert isinstance(num, int)
|
||||||
|
return SymNode(sympy.Integer(num), self.shape_env, int, constant=num)
|
||||||
|
|
||||||
|
def wrap_float(self, num):
|
||||||
|
assert isinstance(num, float)
|
||||||
|
return SymNode(sympy.Integer(num), self.shape_env, float, constant=num)
|
||||||
|
|
||||||
|
def clone(self):
|
||||||
|
return SymNode(self.expr, self.shape_env, self.pytype, constant=self.constant)
|
||||||
|
|
||||||
|
def str(self):
|
||||||
return f"{self.expr}"
|
return f"{self.expr}"
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.str()
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"{self.expr}"
|
return self.str()
|
||||||
|
|
||||||
# Today we error on calling int on a symbolic shape, as this is a very accessible footgun.
|
# Today we error on calling int on a symbolic shape, as this is a very accessible footgun.
|
||||||
def __int__(self):
|
def int_(self):
|
||||||
raise RuntimeError("Trying to extract a concrete int out of a symbolic int")
|
raise RuntimeError("Trying to extract a concrete int out of a symbolic int")
|
||||||
|
|
||||||
# You can manually trigger a guard with this function
|
# You can manually trigger a guard with this function
|
||||||
@ -146,28 +179,35 @@ class PySymInt(object):
|
|||||||
# guard occurred
|
# guard occurred
|
||||||
return int(self.shape_env.evaluate_expr(self.expr))
|
return int(self.shape_env.evaluate_expr(self.expr))
|
||||||
|
|
||||||
def __sym_float__(self):
|
def guard_float(self, file, line):
|
||||||
|
# TODO: use the file/line for some useful diagnostic on why a
|
||||||
|
# guard occurred
|
||||||
|
return float(self.shape_env.evaluate_expr(self.expr))
|
||||||
|
|
||||||
|
def sym_float(self):
|
||||||
if SYM_FUNCTION_MODE:
|
if SYM_FUNCTION_MODE:
|
||||||
return _handle_sym_dispatch(sym_float, (self,), {})
|
r = _handle_sym_dispatch(sym_float, (wrap_node(self),), {})
|
||||||
|
assert isinstance(r, (SymInt, SymFloat)), type(r)
|
||||||
|
return r.node
|
||||||
# TODO: consider constant prop here
|
# TODO: consider constant prop here
|
||||||
# TODO: wrapping the expr with sympy.Float doesn't seem to work, why
|
# TODO: wrapping the expr with sympy.Float doesn't seem to work, why
|
||||||
# not?
|
# not?
|
||||||
return PySymFloat(self.expr, self.shape_env)
|
return SymNode(self.expr, self.shape_env, float)
|
||||||
|
|
||||||
def __bool__(self):
|
def sym_int(self):
|
||||||
|
raise NotImplementedError("sym_int NYI")
|
||||||
|
"""
|
||||||
|
if SYM_FUNCTION_MODE:
|
||||||
|
return _handle_sym_dispatch(sym_int, (self,), {})
|
||||||
|
# TODO: consider constant prop here
|
||||||
|
# XXX: need to cast float to int in sympy; math.floor is wrong
|
||||||
|
# because negatives round to zero
|
||||||
|
return SymNode(self.expr, self.shape_env, int)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def bool_(self):
|
||||||
return bool(self.shape_env.evaluate_expr(self.shape_env.replace(self.expr)))
|
return bool(self.shape_env.evaluate_expr(self.shape_env.replace(self.expr)))
|
||||||
|
|
||||||
class PySymFloat:
|
|
||||||
def __init__(self, expr, shape_env, constant=None):
|
|
||||||
self.expr = expr
|
|
||||||
self.shape_env = shape_env
|
|
||||||
self.constant = constant
|
|
||||||
|
|
||||||
def wrap(self, num):
|
|
||||||
return PySymFloat(sympy.Float(num), self.shape_env, constant=num)
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return f"{self.expr}"
|
|
||||||
|
|
||||||
if HAS_SYMPY:
|
if HAS_SYMPY:
|
||||||
class FloorDiv(sympy.Function):
|
class FloorDiv(sympy.Function):
|
||||||
@ -238,32 +278,45 @@ unary_magic_methods = {
|
|||||||
|
|
||||||
float_magic_methods = {"add", "sub", "mul", "truediv", "ceil", "floor", "eq", "gt", "lt", "le", "ge", "pow"}
|
float_magic_methods = {"add", "sub", "mul", "truediv", "ceil", "floor", "eq", "gt", "lt", "le", "ge", "pow"}
|
||||||
|
|
||||||
def _make_magic(method, func, py_type):
|
def wrap_node(x):
|
||||||
|
if not isinstance(x, SymNode):
|
||||||
|
return x
|
||||||
|
if x.constant is not None:
|
||||||
|
return x.constant
|
||||||
|
if x.pytype is int:
|
||||||
|
return SymInt(x)
|
||||||
|
elif x.pytype is float:
|
||||||
|
return SymFloat(x)
|
||||||
|
else:
|
||||||
|
raise AssertionError(f"unrecognized return type {x.pytype}")
|
||||||
|
|
||||||
|
def _make_node_magic(method, func):
|
||||||
func = lru_cache(256)(func)
|
func = lru_cache(256)(func)
|
||||||
|
|
||||||
def magic_impl(self, other):
|
def binary_magic_impl(self, other):
|
||||||
if method in ["min", "max"]:
|
if method in ["min", "max"]:
|
||||||
op = getattr(builtins, method)
|
op = getattr(builtins, method)
|
||||||
else:
|
else:
|
||||||
op = getattr(operator, method)
|
op = getattr(operator, method)
|
||||||
if SYM_FUNCTION_MODE:
|
if SYM_FUNCTION_MODE:
|
||||||
return _handle_sym_dispatch(op, (self, other), {})
|
r = _handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {})
|
||||||
if isinstance(other, py_type):
|
assert isinstance(r, (SymInt, SymFloat)), type(r)
|
||||||
other_expr = other.expr
|
return r.node
|
||||||
else:
|
assert isinstance(other, SymNode)
|
||||||
assert isinstance(other, sympy.Expr)
|
other_expr = other.expr
|
||||||
other_expr = other
|
|
||||||
# TODO: consider constant prop here
|
# TODO: consider constant prop here
|
||||||
expr = self.shape_env.replace(self.expr)
|
expr = self.shape_env.replace(self.expr)
|
||||||
other_expr = self.shape_env.replace(other_expr)
|
other_expr = self.shape_env.replace(other_expr)
|
||||||
out = func(expr, other_expr)
|
out = func(expr, other_expr)
|
||||||
out = sympy.expand(out)
|
out = sympy.expand(out)
|
||||||
if method in ["truediv"]:
|
if method in ["truediv"]:
|
||||||
return PySymFloat(out, self.shape_env)
|
pytype = float
|
||||||
else:
|
else:
|
||||||
# TODO: relational operators actually technically return a
|
pytype = self.pytype
|
||||||
# PySymBool, this is a type error
|
|
||||||
return py_type(out, self.shape_env)
|
# TODO: relational operators actually technically return a
|
||||||
|
# PySymBool, this is a type error
|
||||||
|
return SymNode(out, self.shape_env, pytype)
|
||||||
|
|
||||||
def unary_magic_impl(self):
|
def unary_magic_impl(self):
|
||||||
if SYM_FUNCTION_MODE:
|
if SYM_FUNCTION_MODE:
|
||||||
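wrap_node is the single choke point where type-erased nodes become user values again: constants unwrap to plain Python numbers, and otherwise pytype selects SymInt versus SymFloat. A condensed, runnable sketch of that decision table (NodeS/SymIntS/SymFloatS are stand-in classes, not the real ones):

# Sketch of wrap_node's decision table.
class NodeS:
    def __init__(self, pytype, constant=None):
        self.pytype = pytype
        self.constant = constant

class SymIntS:
    def __init__(self, node):
        self.node = node

class SymFloatS:
    def __init__(self, node):
        self.node = node

def wrap_node(x):
    if not isinstance(x, NodeS):
        return x                   # non-nodes pass through untouched
    if x.constant is not None:
        return x.constant          # constants collapse to plain numbers
    if x.pytype is int:
        return SymIntS(x)
    elif x.pytype is float:
        return SymFloatS(x)
    else:
        raise AssertionError(f"unrecognized return type {x.pytype}")

assert wrap_node(NodeS(int, constant=3)) == 3
assert isinstance(wrap_node(NodeS(float)), SymFloatS)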
@@ -271,33 +324,55 @@ def _make_magic(method, func, py_type):
             op = getattr(math, method)
         else:
             op = getattr(operator, method)
-            return _handle_sym_dispatch(op, (self,), {})
+            r = _handle_sym_dispatch(op, (wrap_node(self),), {})
+            assert isinstance(r, (SymInt, SymFloat)), type(r)
+            return r.node
         # TODO: consider constant prop here
         expr = self.shape_env.replace(self.expr)
         out = func(expr)
         out = sympy.expand(out)
         if method in ["ceil", "floor"]:
-            return PySymInt(out, self.shape_env)
+            pytype = int
         else:
-            return py_type(out, self.shape_env)
+            pytype = self.pytype
+
+        return SymNode(out, self.shape_env, pytype)

-    # this should be wrapped transparently into torch.SymIntNode
     if method in unary_magic_methods:
-        setattr(py_type, method, unary_magic_impl)
-        setattr(py_type, f"__{method}__", unary_magic_impl)
+        setattr(SymNode, method, unary_magic_impl)
     else:
-        setattr(py_type, method, magic_impl)
-        setattr(py_type, f"__{method}__", magic_impl)
-        if method in reflectable_magic_methods:
-            setattr(py_type, f"__r{method}__", magic_impl)
+        setattr(SymNode, method, binary_magic_impl)

 for method, func in magic_methods.items():
-    _make_magic(method, func, PySymInt)
+    _make_node_magic(method, func)
+
+def _make_user_magic(method, user_type):
+    # User magic takes care of wrapping the other operand into a node,
+    # so that our internal logic can assume everything is nodes
+
+    def unary_magic_impl(self):
+        return wrap_node(getattr(self.node, method)())
+
+    def binary_magic_impl(self, other):
+        return wrap_node(getattr(self.node, method)(self.node.to_node(other)))
+
+    def rbinary_magic_impl(self, other):
+        return wrap_node(getattr(self.node.to_node(other), method)(self.node))
+
+    if method in unary_magic_methods:
+        setattr(user_type, f"__{method}__", unary_magic_impl)
+    else:
+        setattr(user_type, f"__{method}__", binary_magic_impl)
+        if method in reflectable_magic_methods:
+            setattr(user_type, f"__r{method}__", rbinary_magic_impl)
+
+for method, func in magic_methods.items():
+    _make_user_magic(method, SymInt)

 for method, func in magic_methods.items():
     if method not in float_magic_methods:
         continue
-    _make_magic(method, func, PySymFloat)
+    _make_user_magic(method, SymFloat)

 del method
 del func
|
|||||||
return [self.create_symintnode(i) for i in size], [self.create_symintnode(i) for i in stride] # type: ignore[arg-type]
|
return [self.create_symintnode(i) for i in size], [self.create_symintnode(i) for i in stride] # type: ignore[arg-type]
|
||||||
|
|
||||||
def create_symintnode(self, expr: "sympy.Expr"):
|
def create_symintnode(self, expr: "sympy.Expr"):
|
||||||
py_sym_int = PySymInt(expr, self)
|
return SymInt(SymNode(expr, self, int))
|
||||||
cpp_sym_int = torch.SymIntNode.new_symint(py_sym_int) # type: ignore[attr-defined]
|
|
||||||
return cpp_sym_int
|
|
||||||
|
|
||||||
def create_symbol(self, val: int) -> "sympy.Expr":
|
def create_symbol(self, val: int) -> "sympy.Expr":
|
||||||
if not HAS_SYMPY:
|
if not HAS_SYMPY:
|
||||||
|
@@ -498,7 +498,7 @@ class CodeGen(object):
             if isinstance(meta_val, FakeTensor):
                 maybe_type_annotation = f': {dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}'
             elif isinstance(meta_val, py_sym_types):
-                maybe_type_annotation = f': Sym({meta_val.expr})'
+                maybe_type_annotation = f': Sym({meta_val})'
             elif isinstance(meta_val, TensorMetadata):
                 maybe_type_annotation = f': {dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}'