Unify SymIntNode and SymFloatNode into SymNode (#87817)

This refactor was prompted by challenges handling mixed int/float
operations in C++.  A previous version of this patch
added overloads for each permutation of int/float and was unwieldy
https://github.com/pytorch/pytorch/pull/87722/  This PR takes a different
approach.

The general outline of the patch is to combine the C++ types SymIntNode
and SymFloatNode into a single type, SymNode.  This is type erased; we
no longer know statically at C++ if we have an int/float and have to test
it with the is_int()/is_float() virtual methods.  This has a number of
knock on effects.

- We no longer have C++ classes to bind to Python.  Instead, we take an
  entirely new approach to our Python API, where we have a SymInt/SymFloat
  class defined entirely in Python, which hold a SymNode (which corresponds
  to the C++ SymNode).  However, SymNode is not pybind11-bound; instead,
  it lives as-is in Python, and is wrapped into C++ SymNode using PythonSymNode
  when it goes into C++.  This implies a userland rename.

  In principle, it is also possible for the canonical implementation of SymNode
  to be written in C++, and then bound to Python with pybind11 (we have
  this code, although it is commented out.)  However, I did not implement
  this as we currently have no C++ implementations of SymNode.

  Because we do return SymInt/SymFloat from C++ bindings, the C++ binding
  code needs to know how to find these classes.  Currently, this is done
  just by manually importing torch and getting the attributes.

- Because SymInt/SymFloat are easy Python wrappers, __sym_dispatch__ now
  takes SymInt/SymFloat, rather than SymNode, bringing it in line with how
  __torch_dispatch__ works.

Some miscellaneous improvements:

- SymInt now has a constructor that takes SymNode.  Note that this
  constructor is ambiguous if you pass in a subclass of SymNode,
  so an explicit downcast is necessary.  This means toSymFloat/toSymInt
  are no more.  This is a mild optimization as it means rvalue reference
  works automatically.

- We uniformly use the caster for c10::SymInt/SymFloat, rather than
  going the long way via the SymIntNode/SymFloatNode.

- Removed some unnecessary toSymInt/toSymFloat calls in normalize_*
  functions, pretty sure this doesn't do anything.

- guard_int is now a free function, since to guard on an int you cannot
  assume the method exists.  A function can handle both int and SymInt
  inputs.

- We clean up the magic method definition code for SymInt/SymFloat/SymNode.
  ONLY the user classes (SymInt/SymFloat) get magic methods; SymNode gets
  plain methods; this is to help avoid confusion between the two types.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

cc @jansel @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87817
Approved by: https://github.com/albanD, https://github.com/anjali411
This commit is contained in:
Edward Z. Yang
2022-10-27 13:49:11 -07:00
committed by PyTorch MergeBot
parent 2205f56f46
commit 1ff52225f1
54 changed files with 732 additions and 1439 deletions

View File

@ -1 +1 @@
095ee628212f0235ad0d6908bdd514123639fc86 1e9b8bdc75114ac6c16305c970be37a1cd2fdb1c

View File

@ -439,7 +439,7 @@ command = [
"""--error-description=\ """--error-description=\
This line has an isinstance call that directly refers to \ This line has an isinstance call that directly refers to \
int or float. This is error-prone because you may also \ int or float. This is error-prone because you may also \
have wanted to allow SymIntNode or SymFloatNode in your test. \ have wanted to allow SymInt or SymFloat in your test. \
To suppress this lint, use an appropriate type alias defined \ To suppress this lint, use an appropriate type alias defined \
in torch._prims_common; use IntLike/FloatLike when you would accept \ in torch._prims_common; use IntLike/FloatLike when you would accept \
both regular and symbolic numbers, Dim for ints representing \ both regular and symbolic numbers, Dim for ints representing \

View File

@ -95,7 +95,7 @@ c10::SymInt get_nbytes(const Tensor& value) {
if (value.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) { if (value.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {
// Today, the two implementations of SymInt are in Python (proxy tensor), // Today, the two implementations of SymInt are in Python (proxy tensor),
// and lazy tensor (LTC/XLA). // and lazy tensor (LTC/XLA).
// LTC hasn't implemented SymInt support yet though (torch::lazy::SymIntNodeImpl). // LTC hasn't implemented SymInt support yet though
// Once it does, we should remove this check. // Once it does, we should remove this check.
if (value.key_set().has(c10::DispatchKey::Python)) { if (value.key_set().has(c10::DispatchKey::Python)) {
return value.storage().sym_nbytes(); return value.storage().sym_nbytes();

View File

@ -562,7 +562,7 @@ public:
IValue(c10::SymInt i) { IValue(c10::SymInt i) {
if (i.is_symbolic()) { if (i.is_symbolic()) {
tag = Tag::SymInt; tag = Tag::SymInt;
payload.u.as_intrusive_ptr = i.toSymIntNodeImpl().release(); payload.u.as_intrusive_ptr = i.toSymNodeImpl().release();
} else { } else {
tag = Tag::Int; tag = Tag::Int;
payload.u.as_int = i.as_int_unchecked(); payload.u.as_int = i.as_int_unchecked();
@ -578,7 +578,7 @@ public:
IValue(c10::SymFloat i) { IValue(c10::SymFloat i) {
if (i.is_symbolic()) { if (i.is_symbolic()) {
tag = Tag::SymFloat; tag = Tag::SymFloat;
payload.u.as_intrusive_ptr = i.toSymFloatNodeImpl().release(); payload.u.as_intrusive_ptr = i.toSymNodeImpl().release();
} else { } else {
tag = Tag::Double; tag = Tag::Double;
payload.u.as_double = i.as_float_unchecked(); payload.u.as_double = i.as_float_unchecked();
@ -812,10 +812,10 @@ public:
// for both SymFloat and double // for both SymFloat and double
if (s.isSymInt()) { if (s.isSymInt()) {
tag = Tag::SymInt; tag = Tag::SymInt;
payload.u.as_intrusive_ptr = s.toSymInt().toSymIntNodeImpl().release(); payload.u.as_intrusive_ptr = s.toSymInt().toSymNodeImpl().release();
} else if (s.isSymFloat()) { } else if (s.isSymFloat()) {
tag = Tag::SymFloat; tag = Tag::SymFloat;
payload.u.as_intrusive_ptr = s.toSymFloat().toSymFloatNodeImpl().release(); payload.u.as_intrusive_ptr = s.toSymFloat().toSymNodeImpl().release();
} else if (s.isFloatingPoint()) { } else if (s.isFloatingPoint()) {
tag = Tag::Double; tag = Tag::Double;
payload.u.as_double = s.toDouble(); payload.u.as_double = s.toDouble();

View File

@ -219,7 +219,7 @@ inline at::Generator IValue::toGenerator() const& {
inline c10::SymInt IValue::toSymInt() const { inline c10::SymInt IValue::toSymInt() const {
AT_ASSERT(isSymInt() || isInt(), "Expected SymInt or int but got ", tagKind()); AT_ASSERT(isSymInt() || isInt(), "Expected SymInt or int but got ", tagKind());
if (isSymInt()) { if (isSymInt()) {
return c10::SymInt::toSymInt(toIntrusivePtr<c10::SymIntNodeImpl>()); return c10::SymInt(toIntrusivePtr<c10::SymNodeImpl>());
} else { } else {
return c10::SymInt(payload.u.as_int); return c10::SymInt(payload.u.as_int);
} }
@ -228,7 +228,7 @@ inline c10::SymInt IValue::toSymInt() const {
inline c10::SymFloat IValue::toSymFloat() const { inline c10::SymFloat IValue::toSymFloat() const {
AT_ASSERT(isSymFloat() || isDouble(), "Expected SymFloat or double but got ", tagKind()); AT_ASSERT(isSymFloat() || isDouble(), "Expected SymFloat or double but got ", tagKind());
if (isSymFloat()) { if (isSymFloat()) {
return c10::SymFloat::toSymFloat(toIntrusivePtr<c10::SymFloatNodeImpl>()); return c10::SymFloat(toIntrusivePtr<c10::SymNodeImpl>());
} else { } else {
return c10::SymFloat(payload.u.as_double); return c10::SymFloat(payload.u.as_double);
} }

View File

@ -1310,7 +1310,6 @@ struct TORCH_API SymIntType : public Type {
return "SymInt"; return "SymInt";
} }
std::string annotation_str_impl(TypePrinter printer = nullptr) const override { std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
// TODO: will become a Union[SymIntNodeImpl|int] in the near future
return "int"; return "int";
} }
static const TypeKind Kind = TypeKind::SymIntType; static const TypeKind Kind = TypeKind::SymIntType;

View File

@ -194,34 +194,3 @@ TEST(TestScalar, TestFormatting) {
ASSERT_EQ("(2,3.1)", format(Scalar(c10::complex<float>(2.0, 3.1)))); ASSERT_EQ("(2,3.1)", format(Scalar(c10::complex<float>(2.0, 3.1))));
ASSERT_EQ("4", format(Scalar(Scalar(4).toSymInt()))); ASSERT_EQ("4", format(Scalar(Scalar(4).toSymInt())));
} }
TEST(TestSymInt, Basic) {
Scalar foo;
auto a_impl = c10::make_intrusive<c10::SymIntNodeImpl>();
foo = Scalar(a_impl->toSymInt());
ASSERT_EQ(a_impl.use_count(), 2);
Scalar bar{foo};
ASSERT_EQ(a_impl.use_count(), 3);
auto baz = bar;
ASSERT_EQ(a_impl.use_count(), 4);
auto foo2 = std::move(bar);
ASSERT_EQ(a_impl.use_count(), 4);
ASSERT_TRUE(foo2.isSymInt());
// NOLINTNEXTLINE(bugprone-use-after-move,clang-analyzer-cplusplus.Move)
ASSERT_TRUE(bar.isIntegral(false));
foo2 = SymInt(4);
ASSERT_FALSE(foo2.isSymInt());
ASSERT_EQ(foo2.toSymInt().expect_int(), 4);
// NOLINTNEXTLINE(clang-diagnostic-self-assign-overloaded)
foo2 = foo2;
ASSERT_FALSE(foo2.isSymInt());
ASSERT_EQ(foo2.toSymInt().expect_int(), 4);
ASSERT_EQ(a_impl.use_count(), 3);
ASSERT_THROW(foo.to<double>(), c10::Error);
Scalar int_s = 3;
TORCH_CHECK(int_s.toSymInt().expect_int(), 3);
}

View File

@ -958,6 +958,7 @@ libtorch_python_core_sources = [
"torch/csrc/utils/object_ptr.cpp", "torch/csrc/utils/object_ptr.cpp",
"torch/csrc/utils/python_arg_parser.cpp", "torch/csrc/utils/python_arg_parser.cpp",
"torch/csrc/utils/python_dispatch.cpp", "torch/csrc/utils/python_dispatch.cpp",
"torch/csrc/utils/python_symnode.cpp",
"torch/csrc/utils/structseq.cpp", "torch/csrc/utils/structseq.cpp",
"torch/csrc/utils/tensor_apply.cpp", "torch/csrc/utils/tensor_apply.cpp",
"torch/csrc/utils/tensor_dtypes.cpp", "torch/csrc/utils/tensor_dtypes.cpp",

View File

@ -92,8 +92,8 @@ class C10_API Scalar {
SymInt toSymInt() const { SymInt toSymInt() const {
if (Tag::HAS_si == tag) { if (Tag::HAS_si == tag) {
return c10::SymInt::toSymInt(intrusive_ptr<SymIntNodeImpl>::reclaim_copy( return c10::SymInt(intrusive_ptr<SymNodeImpl>::reclaim_copy(
static_cast<SymIntNodeImpl*>(v.p))); static_cast<SymNodeImpl*>(v.p)));
} else { } else {
return toLong(); return toLong();
} }
@ -101,9 +101,8 @@ class C10_API Scalar {
SymFloat toSymFloat() const { SymFloat toSymFloat() const {
if (Tag::HAS_sd == tag) { if (Tag::HAS_sd == tag) {
return c10::SymFloat::toSymFloat( return c10::SymFloat(intrusive_ptr<SymNodeImpl>::reclaim_copy(
intrusive_ptr<SymFloatNodeImpl>::reclaim_copy( static_cast<SymNodeImpl*>(v.p)));
static_cast<SymFloatNodeImpl*>(v.p)));
} else { } else {
return toDouble(); return toDouble();
} }

View File

@ -1,32 +1,27 @@
#include <c10/core/SymFloat.h> #include <c10/core/SymFloat.h>
#include <c10/core/SymFloatNodeImpl.h> #include <c10/core/SymNodeImpl.h>
#include <array> #include <array>
namespace c10 { namespace c10 {
SymFloatNode SymFloat::toSymFloatNodeImpl() const { SymNode SymFloat::toSymNodeImpl() const {
TORCH_CHECK(is_symbolic()); TORCH_CHECK(is_symbolic());
return SymFloatNode::reclaim_copy(toSymFloatNodeImplUnowned()); return SymNode::reclaim_copy(toSymNodeImplUnowned());
} }
static std::array<SymFloatNode, 2> normalize_symfloats( static std::array<SymNode, 2> normalize_symfloats(SymFloat a_, SymFloat b_) {
SymFloat a_, SymNode a, b;
SymFloat b_) {
SymFloatNode a, b;
if (a_.is_symbolic()) if (a_.is_symbolic())
a = a_.toSymFloatNodeImpl(); a = a_.toSymNodeImpl();
if (b_.is_symbolic()) if (b_.is_symbolic())
b = b_.toSymFloatNodeImpl(); b = b_.toSymNodeImpl();
SymFloatNodeImpl* common = a ? a.get() : b.get(); SymNodeImpl* common = a ? a.get() : b.get();
// TODO: technically we need to check that the classes match
if (!a) { if (!a) {
a = common->wrap(a_.as_float_unchecked()); a = common->wrap_float(a_.as_float_unchecked());
a_.toSymFloat(a); //
} }
if (!b) { if (!b) {
b = common->wrap(b_.as_float_unchecked()); b = common->wrap_float(b_.as_float_unchecked());
b_.toSymFloat(b);
} }
return {a, b}; return {a, b};
} }
@ -36,7 +31,7 @@ SymFloat SymFloat::operator+(SymFloat sci) const {
return SymFloat(data_ + sci.data_); return SymFloat(data_ + sci.data_);
} }
auto res = normalize_symfloats(*this, sci); auto res = normalize_symfloats(*this, sci);
return SymFloat::toSymFloat(res[0]->add(res[1])); return SymFloat(res[0]->add(res[1]));
} }
SymFloat SymFloat::operator-(SymFloat sci) const { SymFloat SymFloat::operator-(SymFloat sci) const {
@ -44,7 +39,7 @@ SymFloat SymFloat::operator-(SymFloat sci) const {
return SymFloat(data_ - sci.data_); return SymFloat(data_ - sci.data_);
} }
auto res = normalize_symfloats(*this, sci); auto res = normalize_symfloats(*this, sci);
return SymFloat::toSymFloat(res[0]->sub(res[1])); return SymFloat(res[0]->sub(res[1]));
} }
SymFloat SymFloat::operator*(SymFloat sci) const { SymFloat SymFloat::operator*(SymFloat sci) const {
@ -52,7 +47,7 @@ SymFloat SymFloat::operator*(SymFloat sci) const {
return SymFloat(data_ * sci.data_); return SymFloat(data_ * sci.data_);
} }
auto res = normalize_symfloats(*this, sci); auto res = normalize_symfloats(*this, sci);
return SymFloat::toSymFloat(res[0]->mul(res[1])); return SymFloat(res[0]->mul(res[1]));
} }
SymFloat SymFloat::operator/(SymFloat sci) const { SymFloat SymFloat::operator/(SymFloat sci) const {
@ -60,16 +55,12 @@ SymFloat SymFloat::operator/(SymFloat sci) const {
return SymFloat(data_ / sci.data_); return SymFloat(data_ / sci.data_);
} }
auto res = normalize_symfloats(*this, sci); auto res = normalize_symfloats(*this, sci);
return SymFloat::toSymFloat(res[0]->truediv(res[1])); return SymFloat(res[0]->truediv(res[1]));
}
c10::SymFloat SymFloat::toSymFloat(SymFloatNode sin_sp) {
return c10::SymFloat(std::move(sin_sp));
} }
std::ostream& operator<<(std::ostream& os, SymFloat s) { std::ostream& operator<<(std::ostream& os, SymFloat s) {
if (s.is_symbolic()) { if (s.is_symbolic()) {
os << s.toSymFloatNodeImpl()->str(); os << s.toSymNodeImpl()->str();
} else { } else {
os << s.as_float_unchecked(); os << s.as_float_unchecked();
} }

View File

@ -1,6 +1,6 @@
#pragma once #pragma once
#include <c10/core/SymFloatNodeImpl.h> #include <c10/core/SymNodeImpl.h>
#include <c10/macros/Macros.h> #include <c10/macros/Macros.h>
#include <c10/util/Exception.h> #include <c10/util/Exception.h>
#include <c10/util/intrusive_ptr.h> #include <c10/util/intrusive_ptr.h>
@ -14,20 +14,21 @@ namespace c10 {
class C10_API SymFloat { class C10_API SymFloat {
public: public:
/*implicit*/ SymFloat(double d) : data_(d){}; /*implicit*/ SymFloat(double d) : data_(d){};
SymFloat(SymFloatNode ptr) SymFloat(SymNode ptr)
: data_(std::numeric_limits<double>::quiet_NaN()), ptr_(std::move(ptr)){}; : data_(std::numeric_limits<double>::quiet_NaN()), ptr_(std::move(ptr)) {
TORCH_CHECK(ptr_->is_float());
};
SymFloat() : data_(0.0) {} SymFloat() : data_(0.0) {}
SymFloatNodeImpl* toSymFloatNodeImplUnowned() const { SymNodeImpl* toSymNodeImplUnowned() const {
return ptr_.get(); return ptr_.get();
} }
SymFloatNodeImpl* release() && { SymNodeImpl* release() && {
return std::move(ptr_).release(); return std::move(ptr_).release();
} }
SymFloatNode toSymFloatNodeImpl() const; SymNode toSymNodeImpl() const;
static c10::SymFloat toSymFloat(SymFloatNode sin);
double expect_float() const { double expect_float() const {
TORCH_CHECK(!is_symbolic()); TORCH_CHECK(!is_symbolic());
@ -53,7 +54,7 @@ class C10_API SymFloat {
private: private:
// TODO: optimize to union // TODO: optimize to union
double data_; double data_;
SymFloatNode ptr_; SymNode ptr_;
}; };
C10_API std::ostream& operator<<(std::ostream& os, SymFloat s); C10_API std::ostream& operator<<(std::ostream& os, SymFloat s);

View File

@ -1,20 +0,0 @@
#include <c10/core/SymFloat.h>
#include <c10/core/SymFloatNodeImpl.h>
#include <c10/core/SymIntNodeImpl.h>
namespace c10 {
c10::SymFloat SymFloatNodeImpl::toSymFloat() {
auto sit_sp = SymFloatNode::reclaim_copy(this);
return SymFloat::toSymFloat(sit_sp);
}
c10::SymIntNode SymFloatNodeImpl::ceil() {
TORCH_CHECK(false, "NYI");
}
c10::SymIntNode SymFloatNodeImpl::floor() {
TORCH_CHECK(false, "NYI");
}
} // namespace c10

View File

@ -1,76 +0,0 @@
#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/intrusive_ptr.h>
#include <memory>
#include <mutex>
#include <vector>
namespace c10 {
class SymIntNodeImpl;
using SymIntNode = c10::intrusive_ptr<SymIntNodeImpl>;
class SymFloat;
class SymFloatNodeImpl;
using SymFloatNode = c10::intrusive_ptr<SymFloatNodeImpl>;
class C10_API SymFloatNodeImpl : public c10::intrusive_ptr_target {
public:
c10::SymFloat toSymFloat();
virtual ~SymFloatNodeImpl(){};
template <typename T>
c10::intrusive_ptr<T> dyn_cast() const {
return c10::intrusive_ptr<T>::reclaim_copy(dynamic_cast<T*>(this));
}
virtual SymFloatNode wrap(double num) {
TORCH_CHECK(false, "NYI");
};
virtual SymFloatNode add(const SymFloatNode& other) {
TORCH_CHECK(false, "NYI");
}
virtual SymFloatNode sub(const SymFloatNode& other) {
TORCH_CHECK(false, "NYI");
}
virtual SymFloatNode mul(const SymFloatNode& other) {
TORCH_CHECK(false, "NYI");
}
virtual SymFloatNode truediv(const SymFloatNode& other) {
TORCH_CHECK(false, "NYI");
}
virtual SymFloatNode pow(const SymFloatNode& other) {
TORCH_CHECK(false, "NYI");
}
virtual SymFloatNode eq(const SymFloatNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymFloatNode ne(const SymFloatNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymFloatNode gt(const SymFloatNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymFloatNode lt(const SymFloatNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymFloatNode le(const SymFloatNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymFloatNode ge(const SymFloatNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymIntNode ceil();
virtual SymIntNode floor();
virtual std::string str() {
TORCH_CHECK(false, "NYI");
};
std::ostream& operator<<(std::ostream& os) {
os << str();
return os;
};
};
} // namespace c10

View File

@ -1,47 +1,46 @@
#include <c10/core/SymFloat.h> #include <c10/core/SymFloat.h>
#include <c10/core/SymInt.h> #include <c10/core/SymInt.h>
#include <c10/core/SymIntNodeImpl.h> #include <c10/core/SymNodeImpl.h>
#include <array> #include <array>
namespace c10 { namespace c10 {
static std::array<SymIntNode, 2> normalize_symints(SymInt a_, SymInt b_) { static std::array<SymNode, 2> normalize_symints(SymInt a_, SymInt b_) {
SymIntNode a, b; SymNode a, b;
if (a_.is_symbolic()) if (a_.is_symbolic())
a = a_.toSymIntNodeImpl(); a = a_.toSymNodeImpl();
if (b_.is_symbolic()) if (b_.is_symbolic())
b = b_.toSymIntNodeImpl(); b = b_.toSymNodeImpl();
SymIntNodeImpl* common = a ? a.get() : b.get(); SymNodeImpl* common = a ? a.get() : b.get();
// TODO: technically we need to check that the classes match // TODO: technically we need to check that the classes match
if (!a) { if (!a) {
a = common->wrap(a_.as_int_unchecked()); a = common->wrap_int(a_.as_int_unchecked());
a_.toSymInt(a); //
} }
if (!b) { if (!b) {
b = common->wrap(b_.as_int_unchecked()); b = common->wrap_int(b_.as_int_unchecked());
b_.toSymInt(b);
} }
return {a, b}; return {a, b};
} }
SymIntNode SymInt::toSymIntNodeImpl() const { SymNode SymInt::toSymNodeImpl() const {
TORCH_CHECK(is_symbolic()); TORCH_CHECK(is_symbolic());
return SymIntNode::reclaim_copy(toSymIntNodeImplUnowned()); return SymNode::reclaim_copy(toSymNodeImplUnowned());
} }
c10::SymInt SymInt::toSymInt(SymIntNode sin_sp) { SymInt::SymInt(SymNode sin_sp) {
TORCH_CHECK(sin_sp->is_int());
auto ptr = static_cast<uint64_t>( auto ptr = static_cast<uint64_t>(
reinterpret_cast<uintptr_t>(static_cast<void*>(sin_sp.release()))); reinterpret_cast<uintptr_t>(static_cast<void*>(sin_sp.release())));
auto rep = (ptr & ~MASK) | IS_SYM; auto rep = (ptr & ~MASK) | IS_SYM;
return c10::SymInt(UNCHECKED, static_cast<int64_t>(rep)); data_ = static_cast<int64_t>(rep);
} }
int64_t SymInt::guard_int(const char* file, int64_t line) const { int64_t SymInt::guard_int(const char* file, int64_t line) const {
if (!is_symbolic()) { if (!is_symbolic()) {
return data_; return data_;
} }
SymIntNode a = toSymIntNodeImpl(); SymNode a = toSymNodeImpl();
return a->guard_int(file, line); return a->guard_int(file, line);
} }
@ -49,7 +48,7 @@ SymInt::operator SymFloat() const {
if (!is_symbolic()) { if (!is_symbolic()) {
return SymFloat(double(data_)); return SymFloat(double(data_));
} }
return SymFloat::toSymFloat(toSymIntNodeImpl()->sym_float()); return SymFloat(toSymNodeImpl()->sym_float());
} }
SymInt SymInt::operator+(SymInt sci) const { SymInt SymInt::operator+(SymInt sci) const {
@ -57,7 +56,7 @@ SymInt SymInt::operator+(SymInt sci) const {
return SymInt(data_ + sci.data_); return SymInt(data_ + sci.data_);
} }
auto res = normalize_symints(*this, sci); auto res = normalize_symints(*this, sci);
return SymInt::toSymInt(res[0]->add(res[1])); return SymInt(res[0]->add(res[1]));
} }
SymInt SymInt::operator-(SymInt sci) const { SymInt SymInt::operator-(SymInt sci) const {
@ -65,7 +64,7 @@ SymInt SymInt::operator-(SymInt sci) const {
return SymInt(data_ - sci.data_); return SymInt(data_ - sci.data_);
} }
auto res = normalize_symints(*this, sci); auto res = normalize_symints(*this, sci);
return SymInt::toSymInt(res[0]->sub(res[1])); return SymInt(res[0]->sub(res[1]));
} }
SymInt SymInt::operator*(SymInt sci) const { SymInt SymInt::operator*(SymInt sci) const {
@ -73,7 +72,7 @@ SymInt SymInt::operator*(SymInt sci) const {
return SymInt(data_ * sci.data_); return SymInt(data_ * sci.data_);
} }
auto res = normalize_symints(*this, sci); auto res = normalize_symints(*this, sci);
return SymInt::toSymInt(res[0]->mul(res[1])); return SymInt(res[0]->mul(res[1]));
} }
SymInt SymInt::operator/(SymInt sci) const { SymInt SymInt::operator/(SymInt sci) const {
@ -81,7 +80,7 @@ SymInt SymInt::operator/(SymInt sci) const {
return SymInt(data_ / sci.data_); return SymInt(data_ / sci.data_);
} }
auto res = normalize_symints(*this, sci); auto res = normalize_symints(*this, sci);
return SymInt::toSymInt(res[0]->floordiv(res[1])); return SymInt(res[0]->floordiv(res[1]));
} }
SymInt SymInt::operator%(SymInt sci) const { SymInt SymInt::operator%(SymInt sci) const {
@ -89,7 +88,7 @@ SymInt SymInt::operator%(SymInt sci) const {
return SymInt(data_ % sci.data_); return SymInt(data_ % sci.data_);
} }
auto res = normalize_symints(*this, sci); auto res = normalize_symints(*this, sci);
return SymInt::toSymInt(res[0]->mod(res[1])); return SymInt(res[0]->mod(res[1]));
} }
bool SymInt::operator==(SymInt sci) const { bool SymInt::operator==(SymInt sci) const {
@ -141,14 +140,14 @@ SymInt SymInt::min(SymInt sci) const {
return std::min(data_, sci.data_); return std::min(data_, sci.data_);
} }
auto res = normalize_symints(*this, sci); auto res = normalize_symints(*this, sci);
return SymInt::toSymInt(res[0]->min(res[1])); return SymInt(res[0]->min(res[1]));
} }
SymInt SymInt::max(SymInt sci) const { SymInt SymInt::max(SymInt sci) const {
if (!is_symbolic() && !sci.is_symbolic()) { if (!is_symbolic() && !sci.is_symbolic()) {
return std::max(data_, sci.data_); return std::max(data_, sci.data_);
} }
auto res = normalize_symints(*this, sci); auto res = normalize_symints(*this, sci);
return SymInt::toSymInt(res[0]->max(res[1])); return SymInt(res[0]->max(res[1]));
} }
void SymInt::operator*=(SymInt sci) { void SymInt::operator*=(SymInt sci) {
@ -193,7 +192,7 @@ SymInt SymInt::operator*(int64_t sci) const {
std::ostream& operator<<(std::ostream& os, SymInt s) { std::ostream& operator<<(std::ostream& os, SymInt s) {
if (s.is_symbolic()) { if (s.is_symbolic()) {
os << s.toSymIntNodeImpl()->str(); os << s.toSymNodeImpl()->str();
} else { } else {
os << s.as_int_unchecked(); os << s.as_int_unchecked();
} }
@ -202,7 +201,7 @@ std::ostream& operator<<(std::ostream& os, SymInt s) {
SymInt operator-(SymInt s) { SymInt operator-(SymInt s) {
if (s.is_symbolic()) { if (s.is_symbolic()) {
return SymInt::toSymInt(s.toSymIntNodeImpl()->neg()); return SymInt(s.toSymNodeImpl()->neg());
} else { } else {
return SymInt(-s.as_int_unchecked()); return SymInt(-s.as_int_unchecked());
} }

View File

@ -1,6 +1,6 @@
#pragma once #pragma once
#include <c10/core/SymIntNodeImpl.h> #include <c10/core/SymNodeImpl.h>
#include <c10/macros/Macros.h> #include <c10/macros/Macros.h>
#include <c10/util/Exception.h> #include <c10/util/Exception.h>
#include <c10/util/intrusive_ptr.h> #include <c10/util/intrusive_ptr.h>
@ -12,24 +12,19 @@ namespace c10 {
class SymFloat; class SymFloat;
// `SymInt` is a C++ wrapper class around int64_t data_ which and is used to // SymInt represents either a regular int64_t, or a symbolic integer
// represent concrete dimension values. // (represented in a type erased way as SymNode). The intention is for SymInt
// to represent symbolic sizes that arise when doing shape computation in
// operator kernels. This allows for tracing through programs without baking in
// concrete sizes into kernel calls.
// //
// `SymInt` is also a data type in Pytorch that can be used in function schemas // SymInt has an API equivalent to int64_t. In particular, it is a value type.
// to enable tracing. // Internally, SymInt is represented in a clever packed way, so that it only
// occupies one word of space; but morally, it is a union between an int64_t
// and an intrusive pointer to SymNodeImpl.
// //
// `SymInt` is introduced to enable tracing arithmetic // Invariant: the referenced SymNodeImpl is guaranteed to be a SymNode where
// operations on symbolic integers (e.g. sizes). Tracing symbolic sizes will // is_int() returns true
// allow LTC and AOTAutograd representing dynamic shapes in expression graphs
// faithfully without baking in concrete dimension values.
//
// To trace the operations, SymInt will overload arithmetic operators (e.g. +,
// -, *) and will provide overloads taking SymInt for commonly used math
// functions.
//
// SymInt will be extenteded to represent a union structure Union[int64_t,
// SymIntNodeImpl*] which will be implemented as a single packed int64_t field
// named data_.
class C10_API SymInt { class C10_API SymInt {
public: public:
@ -44,6 +39,7 @@ class C10_API SymInt {
TORCH_CHECK(!is_symbolic()); TORCH_CHECK(!is_symbolic());
}; };
SymInt() : data_(0) {} SymInt() : data_(0) {}
SymInt(SymNode n);
// unchecked c-tor accepting raw `data_` // unchecked c-tor accepting raw `data_`
// One appropriate use for this is when you are constructing a symint // One appropriate use for this is when you are constructing a symint
@ -55,7 +51,7 @@ class C10_API SymInt {
// temporary and then use the move constructor/assignment // temporary and then use the move constructor/assignment
SymInt(const SymInt& s) : data_(0) { SymInt(const SymInt& s) : data_(0) {
if (s.is_symbolic()) { if (s.is_symbolic()) {
*this = SymInt::toSymInt(s.toSymIntNodeImpl()); *this = SymInt(s.toSymNodeImpl());
} else { } else {
data_ = s.data_; data_ = s.data_;
} }
@ -67,7 +63,7 @@ class C10_API SymInt {
SymInt& operator=(const SymInt& s) { SymInt& operator=(const SymInt& s) {
if (this != &s) { if (this != &s) {
if (s.is_symbolic()) { if (s.is_symbolic()) {
*this = SymInt::toSymInt(s.toSymIntNodeImpl()); *this = SymInt(s.toSymNodeImpl());
} else { } else {
data_ = s.data_; data_ = s.data_;
} }
@ -76,7 +72,7 @@ class C10_API SymInt {
} }
SymInt& operator=(SymInt&& s) { SymInt& operator=(SymInt&& s) {
if (this != &s) { if (this != &s) {
release_(); // release the current SymIntNode if any release_(); // release the current SymNode if any
data_ = s.data_; data_ = s.data_;
if (s.is_symbolic()) if (s.is_symbolic())
s.data_ = 0; s.data_ = 0;
@ -86,31 +82,31 @@ class C10_API SymInt {
SymInt clone() const { SymInt clone() const {
if (is_symbolic()) { if (is_symbolic()) {
return toSymIntNodeImplUnowned()->clone()->toSymInt(); return SymInt(toSymNodeImplUnowned()->clone());
} }
return *this; return *this;
} }
SymIntNodeImpl* toSymIntNodeImplUnowned() const { SymNodeImpl* toSymNodeImplUnowned() const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(is_symbolic()); TORCH_INTERNAL_ASSERT_DEBUG_ONLY(is_symbolic());
uint64_t unextended_bits = static_cast<uint64_t>(data_) & ~MASK; uint64_t unextended_bits = static_cast<uint64_t>(data_) & ~MASK;
uint64_t sign_bit_mask = 1ULL << (62 - 1); uint64_t sign_bit_mask = 1ULL << (62 - 1);
// https://stackoverflow.com/questions/42534749/signed-extension-from-24-bit-to-32-bit-in-c // https://stackoverflow.com/questions/42534749/signed-extension-from-24-bit-to-32-bit-in-c
uint64_t extended_bits = (unextended_bits ^ sign_bit_mask) - sign_bit_mask; uint64_t extended_bits = (unextended_bits ^ sign_bit_mask) - sign_bit_mask;
return static_cast<SymIntNodeImpl*>( return static_cast<SymNodeImpl*>(
reinterpret_cast<void*>(static_cast<uintptr_t>(extended_bits))); reinterpret_cast<void*>(static_cast<uintptr_t>(extended_bits)));
} }
void release_() { void release_() {
if (is_symbolic()) { if (is_symbolic()) {
SymIntNode::reclaim(toSymIntNodeImplUnowned()); // steal SymNode::reclaim(toSymNodeImplUnowned()); // steal
} }
} }
SymIntNodeImpl* release() && { SymNodeImpl* release() && {
#ifndef C10_MOBILE #ifndef C10_MOBILE
TORCH_INTERNAL_ASSERT(is_symbolic()); TORCH_INTERNAL_ASSERT(is_symbolic());
auto* r = toSymIntNodeImplUnowned(); auto* r = toSymNodeImplUnowned();
data_ = 0; // transfer ownership data_ = 0; // transfer ownership
return r; return r;
#else #else
@ -118,8 +114,7 @@ class C10_API SymInt {
#endif #endif
} }
SymIntNode toSymIntNodeImpl() const; SymNode toSymNodeImpl() const;
static c10::SymInt toSymInt(SymIntNode sin);
~SymInt() { ~SymInt() {
release_(); release_();

View File

@ -1,11 +0,0 @@
#include <c10/core/SymInt.h>
#include <c10/core/SymIntNodeImpl.h>
namespace c10 {
c10::SymInt SymIntNodeImpl::toSymInt() {
auto sit_sp = SymIntNode::reclaim_copy(this);
return SymInt::toSymInt(sit_sp);
}
} // namespace c10

3
c10/core/SymNodeImpl.cpp Normal file
View File

@ -0,0 +1,3 @@
#include <c10/core/SymNodeImpl.h>
namespace c10 {} // namespace c10

View File

@ -1,6 +1,5 @@
#pragma once #pragma once
#include <c10/core/SymFloatNodeImpl.h>
#include <c10/macros/Macros.h> #include <c10/macros/Macros.h>
#include <c10/util/Exception.h> #include <c10/util/Exception.h>
#include <c10/util/intrusive_ptr.h> #include <c10/util/intrusive_ptr.h>
@ -10,13 +9,12 @@
namespace c10 { namespace c10 {
class SymInt; class SymNodeImpl;
class SymIntNodeImpl; using SymNode = c10::intrusive_ptr<SymNodeImpl>;
class C10_API SymIntNodeImpl : public c10::intrusive_ptr_target { class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
public: public:
c10::SymInt toSymInt(); virtual ~SymNodeImpl(){};
virtual ~SymIntNodeImpl(){};
template <typename T> template <typename T>
c10::intrusive_ptr<T> dyn_cast() const { c10::intrusive_ptr<T> dyn_cast() const {
@ -24,66 +22,87 @@ class C10_API SymIntNodeImpl : public c10::intrusive_ptr_target {
} }
// these could be pure virtual when we implement LTC versions // these could be pure virtual when we implement LTC versions
virtual SymIntNode add(const SymIntNode& other) { virtual bool is_int() {
TORCH_CHECK(false, "NYI"); TORCH_CHECK(false, "NYI");
}; };
virtual SymIntNode sub(const SymIntNode& other) { virtual bool is_float() {
TORCH_CHECK(false, "NYI"); TORCH_CHECK(false, "NYI");
}; };
virtual SymIntNode mul(const SymIntNode& other) { virtual SymNode add(const SymNode& other) {
TORCH_CHECK(false, "NYI"); TORCH_CHECK(false, "NYI");
}; };
virtual SymFloatNode truediv(const SymIntNode& other) { virtual SymNode sub(const SymNode& other) {
TORCH_CHECK(false, "NYI"); TORCH_CHECK(false, "NYI");
}; };
virtual SymIntNode floordiv(const SymIntNode& other) { virtual SymNode mul(const SymNode& other) {
TORCH_CHECK(false, "NYI"); TORCH_CHECK(false, "NYI");
}; };
virtual SymIntNode mod(const SymIntNode& other) { virtual SymNode truediv(const SymNode& other) {
TORCH_CHECK(false, "NYI"); TORCH_CHECK(false, "NYI");
}; };
virtual SymIntNode eq(const SymIntNode& other) { virtual SymNode pow(const SymNode& other) {
TORCH_CHECK(false, "NYI"); TORCH_CHECK(false, "NYI");
}; };
virtual SymIntNode ne(const SymIntNode& other) { virtual SymNode floordiv(const SymNode& other) {
TORCH_CHECK(false, "NYI"); TORCH_CHECK(false, "NYI");
}; };
virtual SymIntNode gt(const SymIntNode& other) { virtual SymNode mod(const SymNode& other) {
TORCH_CHECK(false, "NYI"); TORCH_CHECK(false, "NYI");
}; };
virtual SymIntNode lt(const SymIntNode& other) { virtual SymNode eq(const SymNode& other) {
TORCH_CHECK(false, "NYI"); TORCH_CHECK(false, "NYI");
}; };
virtual SymIntNode le(const SymIntNode& other) { virtual SymNode ne(const SymNode& other) {
TORCH_CHECK(false, "NYI"); TORCH_CHECK(false, "NYI");
}; };
virtual SymIntNode ge(const SymIntNode& other) { virtual SymNode gt(const SymNode& other) {
TORCH_CHECK(false, "NYI"); TORCH_CHECK(false, "NYI");
}; };
virtual SymIntNode ceil() { virtual SymNode lt(const SymNode& other) {
TORCH_CHECK(false, "NYI"); TORCH_CHECK(false, "NYI");
}; };
virtual SymIntNode neg() { virtual SymNode le(const SymNode& other) {
TORCH_CHECK(false, "NYI"); TORCH_CHECK(false, "NYI");
}; };
virtual SymIntNode min(const SymIntNode& other) { virtual SymNode ge(const SymNode& other) {
TORCH_CHECK(false, "NYI"); TORCH_CHECK(false, "NYI");
}; };
virtual SymIntNode max(const SymIntNode& other) { virtual SymNode ceil() {
TORCH_CHECK(false, "NYI"); TORCH_CHECK(false, "NYI");
}; };
virtual SymIntNode clone() { virtual SymNode floor() {
TORCH_CHECK(false, "NYI"); TORCH_CHECK(false, "NYI");
}; };
virtual SymFloatNode sym_float() { virtual SymNode neg() {
TORCH_CHECK(false, "NYI");
};
virtual SymNode min(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode max(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode clone() {
TORCH_CHECK(false, "NYI");
};
virtual SymNode sym_int() {
TORCH_CHECK(false, "NYI"); TORCH_CHECK(false, "NYI");
} }
virtual SymIntNode wrap(int64_t num) { virtual SymNode sym_float() {
TORCH_CHECK(false, "NYI");
}
virtual SymNode wrap_int(int64_t num) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode wrap_float(double num) {
TORCH_CHECK(false, "NYI"); TORCH_CHECK(false, "NYI");
}; };
virtual int64_t guard_int(const char* file, int64_t line) { virtual int64_t guard_int(const char* file, int64_t line) {
TORCH_CHECK(false, "NYI"); TORCH_CHECK(false, "NYI");
}; };
virtual double guard_float(const char* file, int64_t line) {
TORCH_CHECK(false, "NYI");
};
virtual int64_t int_() { virtual int64_t int_() {
TORCH_CHECK(false, "NYI"); TORCH_CHECK(false, "NYI");
}; };

View File

@ -1,7 +1,7 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <c10/core/SymInt.h> #include <c10/core/SymInt.h>
#include <c10/core/SymIntNodeImpl.h> #include <c10/core/SymNodeImpl.h>
using namespace c10; using namespace c10;
#ifndef C10_MOBILE #ifndef C10_MOBILE
@ -20,12 +20,6 @@ TEST(SymIntTest, ConcreteInts) {
check(-4611686018427387904LL); check(-4611686018427387904LL);
} }
TEST(SymIntTest, AddNode) {
auto n = c10::make_intrusive<SymIntNodeImpl>();
auto i = n->toSymInt();
EXPECT_TRUE(i.is_symbolic());
}
TEST(SymIntTest, CheckRange) { TEST(SymIntTest, CheckRange) {
EXPECT_FALSE(SymInt::check_range(INT64_MIN)); EXPECT_FALSE(SymInt::check_range(INT64_MIN));
} }

View File

@ -335,8 +335,8 @@ coverage_ignore_classes = [
"Quantize", "Quantize",
# torch.utils.backcompat # torch.utils.backcompat
"Warning", "Warning",
"SymIntNode", "SymInt",
"SymFloatNode", "SymFloat",
] ]
# The suffix(es) of source filenames. # The suffix(es) of source filenames.

View File

@ -605,7 +605,7 @@ class PytreeThunk:
return x return x
return pytree.tree_unflatten(x, self.spec) return pytree.tree_unflatten(x, self.spec)
KNOWN_TYPES = [torch.Tensor, int, str, float, bool, torch.SymIntNode, torch.SymFloatNode] KNOWN_TYPES = [torch.Tensor, int, str, float, bool, torch.SymInt, torch.SymFloat]
def aot_function( def aot_function(

View File

@ -209,7 +209,7 @@ def _tensor_nbytes(numel, dtype):
def _size_of(node: fx.Node) -> int: def _size_of(node: fx.Node) -> int:
def to_size_hint(s): def to_size_hint(s):
if isinstance(s, torch.SymIntNode): if isinstance(s, torch.SymInt):
py_s = s.get_pyobj() py_s = s.get_pyobj()
return py_s.shape_env.size_hint(py_s.expr) return py_s.shape_env.size_hint(py_s.expr)
assert isinstance(s, int) assert isinstance(s, int)

View File

@ -18,6 +18,8 @@ cond = PyOperator('cond')
def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands): def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
def _unwrap_proxy(e): def _unwrap_proxy(e):
if not isinstance(e, (torch.Tensor, torch.SymInt, torch.SymFloat)):
return e
return get_proxy_slot(e, proxy_mode.tracer, e, lambda e: e.proxy) return get_proxy_slot(e, proxy_mode.tracer, e, lambda e: e.proxy)
assert isinstance(operands, list), "Cond operands must be a list of tensors" assert isinstance(operands, list), "Cond operands must be a list of tensors"

View File

@ -1447,35 +1447,29 @@ TEST(TestSymInt, AddSymbolicInt) {
} }
#ifndef C10_MOBILE #ifndef C10_MOBILE
TEST(TestSymInt, TestIntrusive) { class TestSymNodeImpl : public c10::SymNodeImpl {
auto a = c10::make_intrusive<c10::SymIntNodeImpl>();
auto b = c10::make_intrusive<c10::SymIntNodeImpl>();
ASSERT_EQ(a.use_count(), 1);
ASSERT_EQ(b.use_count(), 1);
auto as = a->toSymInt();
auto bs = b->toSymInt();
ASSERT_EQ(a.use_count(), 2);
ASSERT_EQ(b.use_count(), 2);
as = bs;
ASSERT_EQ(a.use_count(), 1);
ASSERT_EQ(b.use_count(), 3);
}
class TestSymIntNodeImpl : public c10::SymIntNodeImpl {
public: public:
TestSymIntNodeImpl(int64_t i) : i_(i) {} explicit TestSymNodeImpl(int64_t i) : i_(i) {}
bool is_int() override {
return true;
};
bool is_float() override {
return false;
};
bool bool_() override { bool bool_() override {
return static_cast<bool>(i_); return static_cast<bool>(i_);
}; };
#define OPDEF3(NAME, OP, RET) \ #define OPDEF3(NAME, OP, RET) \
RET NAME(const c10::SymIntNode& other) override { \ RET NAME(const c10::SymNode& other) override { \
return make_intrusive<TestSymIntNodeImpl>( \ return make_intrusive<TestSymNodeImpl>( \
this->i_ OP dynamic_cast<TestSymIntNodeImpl*>(other.get())->i_); \ this->i_ OP dynamic_cast<TestSymNodeImpl*>(other.get())->i_); \
} }
#define OPDEF2(NAME, OP) OPDEF3(NAME, OP, c10::SymIntNode) #define OPDEF2(NAME, OP) OPDEF3(NAME, OP, c10::SymNode)
OPDEF2(add, +) OPDEF2(add, +)
OPDEF2(sub, -) OPDEF2(sub, -)
OPDEF2(mul, *) OPDEF2(mul, *)
@ -1494,17 +1488,19 @@ class TestSymIntNodeImpl : public c10::SymIntNodeImpl {
int64_t i_; int64_t i_;
}; };
TEST(TestSymInt, TestSymIntToSymIntNodeDispatch) { TEST(TestSymInt, TestSymIntToSymNodeDispatch) {
auto get = [](c10::SymInt si) { auto get = [](c10::SymInt si) {
auto node = si.toSymIntNodeImpl(); auto node = si.toSymNodeImpl();
return dynamic_cast<TestSymIntNodeImpl*>(node.get())->i_; return dynamic_cast<TestSymNodeImpl*>(node.get())->i_;
}; };
std::vector<int64_t> inputs{0, 1, -1, 4, -4, 777, -777}; std::vector<int64_t> inputs{0, 1, -1, 4, -4, 777, -777};
for (auto i : inputs) { for (auto i : inputs) {
for (auto j : inputs) { for (auto j : inputs) {
auto a = c10::make_intrusive<TestSymIntNodeImpl>(i)->toSymInt(); auto a = c10::SymInt(
auto b = c10::make_intrusive<TestSymIntNodeImpl>(j)->toSymInt(); static_cast<SymNode>(c10::make_intrusive<TestSymNodeImpl>(i)));
auto b = c10::SymInt(
static_cast<SymNode>(c10::make_intrusive<TestSymNodeImpl>(j)));
ASSERT_EQ(get(a + b), i + j); ASSERT_EQ(get(a + b), i + j);
ASSERT_EQ(get(a - b), i - j); ASSERT_EQ(get(a - b), i - j);
ASSERT_EQ(get(a * b), i * j); ASSERT_EQ(get(a * b), i * j);

View File

@ -12,8 +12,9 @@ import itertools
import io import io
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv, PySymInt, sym_float from torch.fx.experimental.symbolic_shapes import ShapeEnv, sym_float, guard_int, SymNode
from torch.utils._python_dispatch import TorchDispatchMode from torch.utils._python_dispatch import TorchDispatchMode
from torch import SymInt
aten = torch.ops.aten aten = torch.ops.aten
@ -116,9 +117,6 @@ def create_symbolic_tensor(name, arg, shape_env, storage_offset=0):
sym_shapes, sym_strides = shape_env.create_symbolic_sizes_strides(arg) sym_shapes, sym_strides = shape_env.create_symbolic_sizes_strides(arg)
return FakeSymbolicTensor(sym_shapes, sym_strides, arg.dtype, arg.layout, arg.requires_grad, arg.device, storage_offset) return FakeSymbolicTensor(sym_shapes, sym_strides, arg.dtype, arg.layout, arg.requires_grad, arg.device, storage_offset)
CPP_SYMINT_CLASS = type(torch.SymIntNode.new_symint(1))
def create_symint(shape_env, i): def create_symint(shape_env, i):
return shape_env.create_symintnode(shape_env.create_symbol(i)) return shape_env.create_symintnode(shape_env.create_symbol(i))
@ -156,8 +154,8 @@ class TestPySymInt(TestCase):
shape_env = ShapeEnv() shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env) x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
self.assertTrue(not isinstance(x.shape[0], PySymInt)) self.assertTrue(not isinstance(x.shape[0], SymNode))
self.assertTrue(isinstance(x.shape[0], CPP_SYMINT_CLASS)) self.assertTrue(isinstance(x.shape[0], SymInt))
self.assertTrue(x.shape[0] == 5) self.assertTrue(x.shape[0] == 5)
self.assertTrue(x.shape[1] == 4) self.assertTrue(x.shape[1] == 4)
@ -165,17 +163,17 @@ class TestPySymInt(TestCase):
self.assertTrue(x.size()[0], 5) self.assertTrue(x.size()[0], 5)
self.assertTrue(x.size()[1], 4) self.assertTrue(x.size()[1], 4)
self.assertTrue(isinstance(x.size()[1], CPP_SYMINT_CLASS)) self.assertTrue(isinstance(x.size()[1], SymInt))
self.assertTrue(x.size()[2] == 3) self.assertTrue(x.size()[2] == 3)
self.assertTrue(x.size(0) == 5) self.assertTrue(x.size(0) == 5)
self.assertTrue(x.size(1) == 4) self.assertTrue(x.size(1) == 4)
self.assertTrue(x.size(2) == 3) self.assertTrue(x.size(2) == 3)
self.assertTrue(isinstance(x.size(2), CPP_SYMINT_CLASS)) self.assertTrue(isinstance(x.size(2), SymInt))
offset = create_symint(shape_env, 2) offset = create_symint(shape_env, 2)
y = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env, offset) y = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env, offset)
self.assertTrue(isinstance(y.storage_offset(), CPP_SYMINT_CLASS)) self.assertTrue(isinstance(y.storage_offset(), SymInt))
self.assertTrue(y.storage_offset() == 2) self.assertTrue(y.storage_offset() == 2)
offset = 2 offset = 2
@ -267,7 +265,7 @@ class TestPySymInt(TestCase):
def test_stride(self): def test_stride(self):
shape_env = ShapeEnv() shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 5), shape_env) x = create_symbolic_tensor("x", torch.randn(5, 5), shape_env)
self.assertIsInstance(x.stride()[0], CPP_SYMINT_CLASS) self.assertIsInstance(x.stride()[0], SymInt)
@skipIfNoSympy @skipIfNoSympy
def test_size_expressions(self): def test_size_expressions(self):
@ -290,7 +288,7 @@ class TestPySymInt(TestCase):
shape_env = ShapeEnv() shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5), shape_env) x = create_symbolic_tensor("x", torch.randn(5), shape_env)
r = sym_float(x.shape[0]) r = sym_float(x.shape[0])
self.assertTrue(isinstance(r, torch.SymFloatNode)) self.assertIsInstance(r, torch.SymFloat, msg=type(r))
@skipIfNoSympy @skipIfNoSympy
def test_aten_ops(self): def test_aten_ops(self):
@ -320,13 +318,13 @@ class TestPySymInt(TestCase):
shape_env = ShapeEnv() shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2) a0 = create_symint(shape_env, 2)
r = torch.empty(a0, device='meta') r = torch.empty(a0, device='meta')
self.assertIsInstance(r.shape[0], CPP_SYMINT_CLASS) self.assertIsInstance(r.shape[0], SymInt)
@skipIfNoSympy @skipIfNoSympy
def test_guard_int(self): def test_guard_int(self):
shape_env = ShapeEnv() shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2) a0 = create_symint(shape_env, 2)
self.assertEqual(a0.guard_int(), 2) self.assertEqual(guard_int(a0), 2)
self.assertEqual(str(shape_env.guards[0][0]), "Eq(s0, 2)") self.assertEqual(str(shape_env.guards[0][0]), "Eq(s0, 2)")
@skipIfNoSympy @skipIfNoSympy
@ -347,7 +345,9 @@ class TestPySymInt(TestCase):
assert func == torch.ops.aten.add.Tensor assert func == torch.ops.aten.add.Tensor
nonlocal sym_int_encountered nonlocal sym_int_encountered
sym_int_encountered = kwargs["alpha"] is a0 # WARNING: do not do identity tests on the outer
# SymInt/SymFloat, they are NOT STABLE
sym_int_encountered = kwargs["alpha"].node is a0.node
kwargs["alpha"] = 0 kwargs["alpha"] = 0
return func(*args) return func(*args)

View File

@ -1,391 +0,0 @@
# -*- coding: utf-8 -*-
# Owner(s): ["oncall: jit"]
from torch._C import _disabled_torch_function_impl
import torch.fx
import torch.nn.functional as F
from torch.testing._internal.common_utils import run_tests, TestCase, skipIfTorchDynamo
import unittest
import torch
import operator
import itertools
import io
from torch.utils._pytree import tree_map
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv, PySymInt, sym_float
from torch.utils._python_dispatch import TorchDispatchMode
aten = torch.ops.aten
try:
import sympy
HAS_SYMPY = True
except ImportError:
HAS_SYMPY = False
skipIfNoSympy = unittest.skipIf(not HAS_SYMPY, "no sympy")
meta_funcs = {}
def register_meta(op):
def decorator(f):
def add_func(op):
meta_funcs[op] = f
tree_map(add_func, op)
return f
return decorator
@register_meta([aten.add.Tensor, aten.sub.Tensor])
def binary_meta(a, b):
return a.new_empty(a.shape)
@register_meta(aten.cat.default)
def cat_meta(tensors, dim=0):
concat_length = 0
shape = tensors[0].shape
for tensor in tensors:
for idx, (common_length, length) in enumerate(zip(shape, tensor.shape)):
if idx == dim:
concat_length = concat_length + length
else:
assert length == common_length
new_shape = list(shape)
new_shape[dim] = concat_length
return tensors[0].new_empty(new_shape)
@register_meta([aten.narrow_copy.default])
def narrow_copy_symint_meta(a, dim, start, length, **kwargs):
shape = []
for i, x in enumerate(a.shape):
if i == dim:
shape.append(length)
else:
shape.append(x)
return a.new_empty(tuple(shape))
@register_meta([aten.expand.default])
def expand_symint_meta(a, size, implicit=False):
return a.new_empty(size)
def create_contiguous(shape):
strides = [1]
for dim in reversed(shape[:-1]):
strides.append(dim * strides[-1])
return list(reversed(strides))
class FakeSymbolicTensor(torch.Tensor):
@staticmethod
def __new__(cls, sym_shape, sym_strides, dtype, layout, requires_grad, device, storage_offset=0):
# TODO: this is wrong in general
sym_stride = create_contiguous(sym_shape)
r = torch.Tensor._make_wrapper_subclass(
cls, sym_shape,
sym_stride, storage_offset,
dtype=dtype, layout=layout, requires_grad=requires_grad,
device=device,
)
return r
__torch_function__ = _disabled_torch_function_impl
def new_empty(self, shape):
return FakeSymbolicTensor(shape, None, self.dtype, self.layout, self.requires_grad, self.device)
@classmethod
def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
if func_overload in meta_funcs:
return meta_funcs[func_overload](*args, **kwargs)
if func_overload == torch.ops.aten.new_empty.default:
self = args[0]
shape = args[1]
return FakeSymbolicTensor(shape, self.stride(), self.dtype, self.layout, self.requires_grad, self.device)
raise RuntimeError(f"operator {func_overload} not supported")
def create_symbolic_tensor(name, arg, shape_env, storage_offset=0):
sym_shapes, sym_strides = shape_env.create_symbolic_sizes_strides(arg)
return FakeSymbolicTensor(sym_shapes, sym_strides, arg.dtype, arg.layout, arg.requires_grad, arg.device, storage_offset)
CPP_SYMINT_CLASS = type(torch.SymIntNode.new_symint(1))
def create_symint(shape_env, i):
return shape_env.create_symintnode(shape_env.create_symbol(i))
@skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)")
class TestPySymInt(TestCase):
@skipIfNoSympy
def test_arith_ops(self):
shape_env = ShapeEnv()
symints = []
for i in range(2, 5):
symints.append((i, create_symint(shape_env, i)))
ops = [operator.add, operator.sub, operator.floordiv, operator.mul, operator.mod]
for op in ops:
for args in itertools.permutations(symints, 2):
if not isinstance(args[0][1], int) and ((op != operator.mod or op != operator.floordiv) and args[1][0] != 0):
self.assertTrue(op(args[0][1], args[1][1]) == op(args[0][0], args[1][0]))
@skipIfNoSympy
def test_reverse_arith_ops(self):
shape_env = ShapeEnv()
a = create_symint(shape_env, 2)
self.assertTrue(5 // a == 5 // 2)
a = create_symint(shape_env, 2)
self.assertTrue(5 * a == 5 * 2)
@skipIfNoSympy
def test_roundtrip(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
self.assertTrue(not isinstance(x.shape[0], PySymInt))
self.assertTrue(isinstance(x.shape[0], CPP_SYMINT_CLASS))
self.assertTrue(x.shape[0] == 5)
self.assertTrue(x.shape[1] == 4)
self.assertTrue(x.shape[2], 3)
self.assertTrue(x.size()[0], 5)
self.assertTrue(x.size()[1], 4)
self.assertTrue(isinstance(x.size()[1], CPP_SYMINT_CLASS))
self.assertTrue(x.size()[2] == 3)
self.assertTrue(x.size(0) == 5)
self.assertTrue(x.size(1) == 4)
self.assertTrue(x.size(2) == 3)
self.assertTrue(isinstance(x.size(2), CPP_SYMINT_CLASS))
offset = create_symint(shape_env, 2)
y = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env, offset)
self.assertTrue(isinstance(y.storage_offset(), CPP_SYMINT_CLASS))
self.assertTrue(y.storage_offset() == 2)
offset = 2
z = create_symbolic_tensor("z", torch.randn(5, 4, 3), shape_env, offset)
self.assertTrue(isinstance(z.storage_offset(), int))
self.assertTrue(z.storage_offset() == 2)
@skipIfNoSympy
def test_binary(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
y = create_symbolic_tensor("y", torch.randn(5, 4, 3), shape_env)
z = x + y
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)
# broadcasting
y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env)
z = x + y
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)
@skipIfNoSympy
def test_symint_args(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
y = create_symbolic_tensor("y", torch.randn(5, 4, 1), shape_env)
LAST_DIM = 2
z = x.narrow_copy(LAST_DIM, 0, y.shape[LAST_DIM])
self.assertTrue(z.shape[2] == y.shape[2])
# arithmetic expr with two symints
z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - y.shape[LAST_DIM])
self.assertTrue(z.shape[2] == 2)
# arithmetic expr with a symint and python int
z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - 1)
self.assertTrue(z.shape[2] == 2)
@skipIfNoSympy
def test_symint_vargs(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env)
# varargs
z = y.expand(x.shape[0], y.shape[1], x.shape[2])
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)
# shape list
z = y.expand((x.shape[0], y.shape[1], x.shape[2]))
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)
# mixed python symints and ints
z = y.expand(x.shape[0], y.shape[1], 3)
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)
# mixed python symints and ints in a list
z = y.expand((x.shape[0], y.shape[1], 3))
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)
# mixed python symints and ints
z = y.expand(5, y.shape[1], x.shape[2])
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)
# mixed python ints and symints in a list
z = y.expand((5, y.shape[1], x.shape[2]))
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)
z = y.expand((y.shape[1],))
z = y.expand(y.shape[1])
@skipIfNoSympy
def test_stride(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 5), shape_env)
self.assertIsInstance(x.stride()[0], CPP_SYMINT_CLASS)
@skipIfNoSympy
def test_size_expressions(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
expand_x = x.expand(x.shape[0], x.shape[0])
if expand_x.shape[0] > 3:
result = expand_x + expand_x
else:
result = expand_x + expand_x
gt_op = shape_env.guards[0][0]
self.assertTrue(isinstance(gt_op, sympy.core.relational.StrictGreaterThan))
self.assertTrue(str(x.shape[0]), str(gt_op.args[0]))
self.assertTrue(str(expand_x.shape[1]), str(x.shape[0]))
self.assertTrue(str(expand_x.shape[1]), str(result.shape[0]))
@skipIfNoSympy
def test_int_to_float(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
r = sym_float(x.shape[0])
self.assertTrue(isinstance(r, torch.SymFloatNode))
@skipIfNoSympy
def test_aten_ops(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
torch.ops.aten.narrow_copy.default(x, 0, 0, x.shape[0])
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
torch.ops.aten.expand.default(x, [x.shape[0], x.shape[1], x.shape[2]])
def test_fx_trace_intlist(self):
class CustomModule(torch.nn.Module):
def forward(self, x):
bs, c, h, w = x.shape
return F.pad(x, (0, w % 2, 0, h % 2, 0, 0))
m = CustomModule()
x = torch.rand(1, 3, 4, 4)
# should not TypeError: pad(): argument 'pad' (position 2) must be
# tuple of ints, not tuple
torch.fx.symbolic_trace(m)
@skipIfNoSympy
def test_meta_symint(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2)
r = torch.empty(a0, device='meta')
self.assertIsInstance(r.shape[0], CPP_SYMINT_CLASS)
@skipIfNoSympy
def test_guard_int(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2)
self.assertEqual(a0.guard_int(), 2)
self.assertEqual(str(shape_env.guards[0][0]), "s0")
self.assertEqual(shape_env.guards[0][1], 2)
@skipIfNoSympy
def test_int_conversion(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2)
self.assertRaisesRegex(RuntimeError, "Trying to extract", lambda: int(a0))
@skipIfNoSympy
def test_symint_as_scalar(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2)
sym_int_encountered = False
class TestSymInt(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
assert func == torch.ops.aten.add.Tensor
nonlocal sym_int_encountered
sym_int_encountered = kwargs["alpha"] is a0
kwargs["alpha"] = 0
return func(*args)
x = torch.rand([4, 4])
with TestSymInt():
y = torch.add(x, x, alpha=a0)
self.assertTrue(sym_int_encountered)
@skipIfNoSympy
@unittest.mock.patch('sys.stdout', new_callable=io.StringIO)
def test_print_readable_with_symints(self, mock_stdout):
def f(a, b):
dim0 = a.shape[0] + b.shape[0]
dim1 = a.shape[1] + b.shape[1]
d = a.new_empty(dim0, dim1)
d = torch.ops.aten.native_dropout(d, 0.5, train=True)
return d
fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5, 3), torch.randn(4, 3))
fx_g.print_readable()
self.assertExpectedInline(mock_stdout.getvalue().strip(), """\
class f(torch.nn.Module):
def forward(self, a_1: f32[t0.size(0),t0.size(1)], b_1: f32[t1.size(0),t0.size(1)]):
# No stacktrace found for following nodes
sym_size: Sym(t0.size(0)) = torch.ops.aten.sym_size(a_1, 0)
sym_size_1: Sym(t1.size(0)) = torch.ops.aten.sym_size(b_1, 0)
add: Sym(t0.size(0) + t1.size(0)) = sym_size + sym_size_1; sym_size = sym_size_1 = None
sym_size_2: Sym(t0.size(1)) = torch.ops.aten.sym_size(a_1, 1)
sym_size_3: Sym(t0.size(1)) = torch.ops.aten.sym_size(b_1, 1); b_1 = None
add_1: Sym(2*t0.size(1)) = sym_size_2 + sym_size_3; sym_size_2 = sym_size_3 = None
new_empty: f32[t0.size(0) + t1.size(0),2*t0.size(1)] = torch.ops.aten.new_empty.default(a_1, [add, add_1], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False); a_1 = add = add_1 = None
native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True); new_empty = None
getitem: f32[t0.size(0) + t1.size(0),2*t0.size(1)] = native_dropout[0]
getitem_1: b8[t0.size(0) + t1.size(0),2*t0.size(1)] = native_dropout[1]; native_dropout = None
return (getitem, getitem_1)""") # noqa: B950
if __name__ == '__main__':
run_tests()

View File

@ -875,8 +875,7 @@ def forward(self, a_1):
self.assertExpectedInline(r, """\ self.assertExpectedInline(r, """\
def forward(self, a_1): def forward(self, a_1):
sym_size = torch.ops.aten.sym_size(a_1, 0) sym_size = torch.ops.aten.sym_size(a_1, 0)
sym_float = torch.fx.experimental.symbolic_shapes.sym_float(sym_size); sym_size = None pow_1 = sym_size ** 0.5; sym_size = None
pow_1 = sym_float ** 0.5; sym_float = None
div = torch.ops.aten.div.Tensor(a_1, pow_1); a_1 = pow_1 = None div = torch.ops.aten.div.Tensor(a_1, pow_1); a_1 = pow_1 = None
return div""") return div""")
@ -949,7 +948,7 @@ def forward(self, a_1):
fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5), torch.randn(4)) fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5), torch.randn(4))
meta_c = _get_node(fx_g, lambda x: x.target == aten.new_empty.default) meta_c = _get_node(fx_g, lambda x: x.target == aten.new_empty.default)
meta_d = _get_node(fx_g, lambda x: x.target == operator.add) meta_d = _get_node(fx_g, lambda x: x.target == operator.add)
self.assertTrue(meta_c.meta['val'].shape[0].get_pyobj().expr == meta_d.meta['val'].expr) self.assertTrue(meta_c.meta['val'].shape[0].get_pyobj().expr == meta_d.meta['val'].node.expr)
def test_metadata_fresh(self): def test_metadata_fresh(self):
def f(x): def f(x):

View File

@ -207,8 +207,8 @@ class TestPublicBindings(TestCase):
"StreamObjType", "StreamObjType",
"StringType", "StringType",
"SUM", "SUM",
"SymFloatNode", "SymFloat",
"SymIntNode", "SymInt",
"TensorType", "TensorType",
"ThroughputBenchmark", "ThroughputBenchmark",
"TracingState", "TracingState",

View File

@ -291,7 +291,7 @@ PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
for (auto i : c10::irange(prop.size())) { for (auto i : c10::irange(prop.size())) {
auto si = prop[i]; auto si = prop[i];
if (si.is_symbolic()) { if (si.is_symbolic()) {
auto py_symint = py::cast(si.toSymIntNodeImpl()).release().ptr(); auto py_symint = py::cast(si).release().ptr();
PyTuple_SetItem(tup, (Py_ssize_t) i, py_symint); PyTuple_SetItem(tup, (Py_ssize_t) i, py_symint);
} else { } else {
PyTuple_SetItem(tup, (Py_ssize_t) i, PyLong_FromUnsignedLong(si.as_int_unchecked())); PyTuple_SetItem(tup, (Py_ssize_t) i, PyLong_FromUnsignedLong(si.as_int_unchecked()));
@ -313,7 +313,7 @@ return PyLong_FromUnsignedLong((int64_t) prop);
""" """
GETTER_BODY_SYMINT = """\ GETTER_BODY_SYMINT = """\
return prop.is_symbolic() ? py::cast(prop.toSymIntNodeImpl()).release().ptr() : PyLong_FromUnsignedLong(prop.as_int_unchecked()); return prop.is_symbolic() ? py::cast(prop).release().ptr() : PyLong_FromUnsignedLong(prop.as_int_unchecked());
""" """
GETTER_BODY_DOUBLE = """\ GETTER_BODY_DOUBLE = """\

View File

@ -5,7 +5,7 @@
#include <Python.h> #include <Python.h>
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <c10/core/SymIntNodeImpl.h> #include <c10/core/SymNodeImpl.h>
#include "torch/csrc/autograd/generated/Functions.h" #include "torch/csrc/autograd/generated/Functions.h"
#include "torch/csrc/autograd/python_cpp_function.h" #include "torch/csrc/autograd/python_cpp_function.h"
#include <torch/csrc/autograd/python_variable.h> #include <torch/csrc/autograd/python_variable.h>

View File

@ -240,12 +240,7 @@ static PyObject * THPVariable_numel(PyObject* self, PyObject* args)
if (jit::tracer::isTracing()) { if (jit::tracer::isTracing()) {
return wrap(jit::tracer::getNumelOf(self_)); return wrap(jit::tracer::getNumelOf(self_));
} else { } else {
auto si = self_.sym_numel(); return py::cast(self_.sym_numel()).release().ptr();
if (si.is_symbolic()) {
return py::cast(si.toSymIntNodeImpl()).release().ptr();
} else {
return THPUtils_packInt64(si.as_int_unchecked());
}
} }
END_HANDLE_TH_ERRORS END_HANDLE_TH_ERRORS
} }

View File

@ -722,7 +722,7 @@ def gen_pyi(
binop += "_" binop += "_"
out_suffix = "" out_suffix = ""
unsorted_tensor_method_hints[binop].append( unsorted_tensor_method_hints[binop].append(
"def {}(self, other: Union[Tensor, Number, torch.SymIntNode, torch.SymFloatNode]{})" "def {}(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat]{})"
" -> Tensor: ...".format(binop, out_suffix) " -> Tensor: ...".format(binop, out_suffix)
) )
for binop in ["add", "sub"]: for binop in ["add", "sub"]:
@ -732,7 +732,7 @@ def gen_pyi(
binop += "_" binop += "_"
out_suffix = "" out_suffix = ""
unsorted_tensor_method_hints[binop].append( unsorted_tensor_method_hints[binop].append(
"def {}(self, other: Union[Tensor, Number, torch.SymIntNode, torch.SymFloatNode], " "def {}(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat], "
"*, alpha: Optional[Number]=1{})" "*, alpha: Optional[Number]=1{})"
" -> Tensor: ...".format(binop, out_suffix) " -> Tensor: ...".format(binop, out_suffix)
) )

View File

@ -169,20 +169,6 @@ class Future(object):
def _jit_set_num_profiled_runs(num: _size) -> _size: ... def _jit_set_num_profiled_runs(num: _size) -> _size: ...
class SymIntNode(object):
def get_pyobj(self) -> Any: ...
@staticmethod
def new_symint(obj) -> SymIntNode: ...
class SymFloatNode(object):
def get_pyobj(self) -> Any: ...
@staticmethod
def new_symfloat(obj) -> SymFloatNode: ...
def __ceil__(self) -> SymIntNode: ...
# Defined in torch/csrc/jit/passes/xnnpack_rewrite.h # Defined in torch/csrc/jit/passes/xnnpack_rewrite.h
class MobileOptimizerType: class MobileOptimizerType:
... ...

View File

@ -47,7 +47,7 @@ __all__ = [
'is_deterministic_algorithms_warn_only_enabled', 'is_deterministic_algorithms_warn_only_enabled',
'set_deterministic_debug_mode', 'get_deterministic_debug_mode', 'set_deterministic_debug_mode', 'get_deterministic_debug_mode',
'set_float32_matmul_precision', 'get_float32_matmul_precision', 'set_float32_matmul_precision', 'get_float32_matmul_precision',
'set_warn_always', 'is_warn_always_enabled', 'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat',
] ]
################################################################################ ################################################################################
@ -196,6 +196,67 @@ else:
if TYPE_CHECKING: if TYPE_CHECKING:
import torch._C as _C import torch._C as _C
class SymInt:
"""
Like an int (including magic methods), but redirects all operations on the
wrapped node. This is used in particular to symbolically record operations
in the symbolic shape workflow.
"""
def __init__(self, node):
from torch.fx.experimental.symbolic_shapes import SymNode
assert isinstance(node, SymNode)
# This field MUST be named node; C++ binding code assumes that this
# class has a field named node that stores SymNode
self.node = node
# Magic methods installed later
def __bool__(self):
return self.node.bool_()
def __int__(self):
return self.node.int_()
def __sym_float__(self):
return SymFloat(self.node.sym_float())
def __repr__(self):
return self.node.str()
# For BC; direct access of node is OK too
def get_pyobj(self):
return self.node
class SymFloat:
"""
Like an float (including magic methods), but redirects all operations on the
wrapped node. This is used in particular to symbolically record operations
in the symbolic shape workflow.
"""
def __init__(self, node):
from torch.fx.experimental.symbolic_shapes import SymNode
assert isinstance(node, SymNode)
# This field MUST be named node; C++ binding code assumes that this
# class has a field named node that stores SymNode
self.node = node
# Magic methods installed later
def __bool__(self):
return self.node.bool_()
def __sym_int__(self):
return SymInt(self.node.sym_int())
def __repr__(self):
return self.node.str()
# For BC; direct access of node is OK too
def get_pyobj(self):
return self.node
# Check to see if we can load C extensions, and if not provide some guidance # Check to see if we can load C extensions, and if not provide some guidance
# on what the problem might be. # on what the problem might be.
try: try:
@ -941,7 +1002,6 @@ from ._linalg_utils import ( # type: ignore[misc]
lstsq, lstsq,
) )
def _register_device_module(device_type, module): def _register_device_module(device_type, module):
r"""Register an external runtime module of the specific :attr:`device_type` r"""Register an external runtime module of the specific :attr:`device_type`
supported by torch. supported by torch.
@ -971,3 +1031,6 @@ if 'TORCH_CUDA_SANITIZER' in os.environ:
import torch.cuda._sanitizer as csan import torch.cuda._sanitizer as csan
csan.enable_cuda_sanitizer() csan.enable_cuda_sanitizer()
# Populate magic methods on SymInt and SymFloat
import torch.fx.experimental.symbolic_shapes

View File

@ -337,7 +337,7 @@ class TensorVariable(VariableTracker):
from . import UserDefinedObjectVariable from . import UserDefinedObjectVariable
return UserDefinedObjectVariable(example_value) return UserDefinedObjectVariable(example_value)
elif isinstance(example_value, torch.SymIntNode): elif isinstance(example_value, torch.SymInt):
proxy.node.meta["example_value"] = example_value proxy.node.meta["example_value"] = example_value
return cls(proxy, **options) return cls(proxy, **options)
else: else:

View File

@ -40,11 +40,9 @@ class GraphLowering(torch.fx.Interpreter):
else: else:
size, stride = self._shape_env.create_symbolic_sizes_strides(ex) size, stride = self._shape_env.create_symbolic_sizes_strides(ex)
size = [ size = [i.get_pyobj().expr if isinstance(i, torch.SymInt) else i for i in size]
i.get_pyobj().expr if isinstance(i, torch.SymIntNode) else i for i in size
]
stride = [ stride = [
i.get_pyobj().expr if isinstance(i, torch.SymIntNode) else i for i in stride i.get_pyobj().expr if isinstance(i, torch.SymInt) else i for i in stride
] ]
return size, stride return size, stride

View File

@ -392,8 +392,8 @@ def _elementwise_meta(
# Number case # Number case
# NOTE: this case is not currently exercised # NOTE: this case is not currently exercised
# TODO: fix number type promotion (bool, complex->float) # TODO: fix number type promotion (bool, complex->float)
assert not isinstance(number, torch.SymIntNode), "NYI" assert not isinstance(number, torch.SymInt), "NYI"
assert not isinstance(number, torch.SymFloatNode), "NYI" assert not isinstance(number, torch.SymFloat), "NYI"
return TensorMeta(number) return TensorMeta(number)
@ -932,7 +932,7 @@ bitwise_xor = _make_elementwise_binary_prim(
# div prim performs truncation division on integer inputs # div prim performs truncation division on integer inputs
# and true division for floating and complex inputs # and true division for floating and complex inputs
def _div_aten(a, b): def _div_aten(a, b):
is_integral = isinstance(a, (bool, int, torch.SymIntNode)) or ( is_integral = isinstance(a, (bool, int, torch.SymInt)) or (
isinstance(a, torch.Tensor) and utils.is_integer_dtype(a.dtype) isinstance(a, torch.Tensor) and utils.is_integer_dtype(a.dtype)
) )

View File

@ -42,18 +42,18 @@ ShapeType = Union[torch.Size, List[int], Tuple[int, ...]]
StrideType = Union[List[int], Tuple[int, ...]] StrideType = Union[List[int], Tuple[int, ...]]
DimsType = Union[int, List[int], Tuple[int, ...]] DimsType = Union[int, List[int], Tuple[int, ...]]
DimsSequenceType = Union[List[int], Tuple[int, ...]] DimsSequenceType = Union[List[int], Tuple[int, ...]]
# TODO: Type[torch.SymIntNode], Type[torch.SymFloatNode] # TODO: Type[torch.SymInt], Type[torch.SymFloat]
NumberTypeType = Union[Type[bool], Type[int], Type[float], Type[complex]] NumberTypeType = Union[Type[bool], Type[int], Type[float], Type[complex]]
# TODO: This needs a lot more type annotations # TODO: This needs a lot more type annotations
# NumberType = Union[bool, int, float, complex, torch.SymIntNode, torch.SymFloatNode] # NumberType = Union[bool, int, float, complex, torch.SymInt, torch.SymFloat]
NumberType = Union[bool, int, float, complex] NumberType = Union[bool, int, float, complex]
Number = (bool, int, float, complex, torch.SymIntNode, torch.SymFloatNode) Number = (bool, int, float, complex, torch.SymInt, torch.SymFloat)
# I don't call it Integral because numbers.Integral includes bool, but IntLike # I don't call it Integral because numbers.Integral includes bool, but IntLike
# does not # does not
Dim = int Dim = int
IntLike = (int, torch.SymIntNode) IntLike = (int, torch.SymInt)
FloatLike = (float, torch.SymFloatNode) FloatLike = (float, torch.SymFloat)
IntWithoutSymInt = int IntWithoutSymInt = int
FloatWithoutSymFloat = float FloatWithoutSymFloat = float
DeviceLikeType = Union[str, torch.device] DeviceLikeType = Union[str, torch.device]
@ -1113,10 +1113,10 @@ class RETURN_TYPE(Enum):
# TODO: when NumberType contains the sym types, can simplify this # TODO: when NumberType contains the sym types, can simplify this
def number_type(x: Union[NumberType, torch.SymIntNode, torch.SymFloatNode]) -> Type: def number_type(x: Union[NumberType, torch.SymInt, torch.SymFloat]) -> Type:
if isinstance(x, torch.SymIntNode): if isinstance(x, torch.SymInt):
return int return int
elif isinstance(x, torch.SymFloatNode): elif isinstance(x, torch.SymFloat):
return float return float
else: else:
return type(x) return type(x)

View File

@ -656,7 +656,7 @@ class FakeTensorMode(TorchDispatchMode):
return args[0].fake_device return args[0].fake_device
flat_arg_fake_tensors = tree_flatten_only(FakeTensor, (args, kwargs)) flat_arg_fake_tensors = tree_flatten_only(FakeTensor, (args, kwargs))
flat_symints = tree_flatten_only(torch.SymIntNode, (args, kwargs)) flat_symints = tree_flatten_only(torch.SymInt, (args, kwargs))
has_symbolic_sizes = ( has_symbolic_sizes = (
any([i._has_symbolic_sizes_strides for i in flat_arg_fake_tensors]) any([i._has_symbolic_sizes_strides for i in flat_arg_fake_tensors])
or len(flat_symints) > 0 or len(flat_symints) > 0

View File

@ -59,7 +59,7 @@ PyObject* THPSize_NewFromSymSizes(const at::Tensor& self_) {
TORCH_CHECK( TORCH_CHECK(
!torch::jit::tracer::isTracing(), !torch::jit::tracer::isTracing(),
"JIT Tracing of SymInts isn't supported"); "JIT Tracing of SymInts isn't supported");
auto py_symint = py::cast(si.toSymIntNodeImpl()).release().ptr(); auto py_symint = py::cast(si).release().ptr();
if (!py_symint) if (!py_symint)
throw python_error(); throw python_error();
PyTuple_SET_ITEM(ret.get(), i, py_symint); PyTuple_SET_ITEM(ret.get(), i, py_symint);
@ -98,7 +98,7 @@ static PyObject* THPSize_pynew(
if (THPUtils_checkLong(item)) { if (THPUtils_checkLong(item)) {
continue; continue;
} }
if (torch::is_symint_node(item)) { if (torch::is_symint(item)) {
continue; continue;
} }
if (torch::jit::tracer::isTracing() && isTracedZeroDimVar(item)) { if (torch::jit::tracer::isTracing() && isTracedZeroDimVar(item)) {
@ -135,7 +135,7 @@ static PyObject* THPSize_repr(THPSize* self) {
auto item = PyTuple_GET_ITEM(self, i); auto item = PyTuple_GET_ITEM(self, i);
auto ih = py::handle(item); auto ih = py::handle(item);
repr += torch::is_symint_node(ih) repr += torch::is_symint(ih)
? std::string(py::str(ih)) ? std::string(py::str(ih))
: std::to_string(THPUtils_unpackLong(PyTuple_GET_ITEM(self, i))); : std::to_string(THPUtils_unpackLong(PyTuple_GET_ITEM(self, i)));
} }

View File

@ -2646,9 +2646,8 @@ c10::SymInt ConcretePyInterpreterVTable::sym_numel(
"Cannot call numel on a tensor with symbolic shapes/strides"); "Cannot call numel on a tensor with symbolic shapes/strides");
return self->sym_numel_default(); return self->sym_numel_default();
} }
return torch::is_symint_node(out) return torch::is_symint(out) ? out.cast<c10::SymInt>()
? out.cast<c10::SymIntNodeImpl*>()->toSymInt() : c10::SymInt{py::cast<int64_t>(out)};
: c10::SymInt{py::cast<int64_t>(out)};
} }
c10::SymInt ConcretePyInterpreterVTable::sym_storage_offset( c10::SymInt ConcretePyInterpreterVTable::sym_storage_offset(
@ -2669,9 +2668,8 @@ c10::SymInt ConcretePyInterpreterVTable::sym_storage_offset(
if (out.is(py::none())) { if (out.is(py::none())) {
return self->sym_storage_offset_default(); return self->sym_storage_offset_default();
} }
return torch::is_symint_node(out) return torch::is_symint(out) ? out.cast<c10::SymInt>()
? out.cast<c10::SymIntNodeImpl*>()->toSymInt() : c10::SymInt{py::cast<int64_t>(out)};
: c10::SymInt{py::cast<int64_t>(out)};
} }
c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_strides( c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_strides(
@ -2701,9 +2699,8 @@ c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_strides(
py::list symints; py::list symints;
for (auto it = out.begin(); it != out.end(); it++) { for (auto it = out.begin(); it != out.end(); it++) {
auto elm = *it; auto elm = *it;
auto si = torch::is_symint_node(elm) auto si = torch::is_symint(elm) ? elm.cast<c10::SymInt>()
? elm.cast<c10::SymIntNodeImpl*>()->toSymInt() : c10::SymInt{py::cast<int64_t>(elm)};
: c10::SymInt{py::cast<int64_t>(elm)};
symints.append(si.as_int_unchecked()); symints.append(si.as_int_unchecked());
} }

View File

@ -13,7 +13,7 @@
#if (!defined(FBCODE_CAFFE2) && defined(BUILD_ONEDNN_GRAPH)) #if (!defined(FBCODE_CAFFE2) && defined(BUILD_ONEDNN_GRAPH))
#include <torch/csrc/jit/codegen/onednn/interface.h> #include <torch/csrc/jit/codegen/onednn/interface.h>
#endif #endif
#include <c10/core/SymIntNodeImpl.h> #include <c10/core/SymNodeImpl.h>
#include <torch/csrc/jit/frontend/ir_emitter.h> #include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/frontend/tracer.h> #include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/ir/irparser.h> #include <torch/csrc/jit/ir/irparser.h>
@ -99,7 +99,6 @@
#include <torch/csrc/jit/tensorexpr/tensorexpr_init.h> #include <torch/csrc/jit/tensorexpr/tensorexpr_init.h>
#include <torch/csrc/utils/cpp_stacktraces.h> #include <torch/csrc/utils/cpp_stacktraces.h>
#include <c10/core/SymFloat.h>
#include <c10/macros/Export.h> #include <c10/macros/Export.h>
#include <c10/util/irange.h> #include <c10/util/irange.h>
#include <c10/util/signal_handler.h> #include <c10/util/signal_handler.h>
@ -126,249 +125,11 @@ using c10::Argument;
using c10::FunctionSchema; using c10::FunctionSchema;
using c10::SchemaArgType; using c10::SchemaArgType;
using c10::SchemaArgument; using c10::SchemaArgument;
using c10::SymFloat; using c10::SymNode;
using c10::SymFloatNode;
using c10::SymIntNode;
using caffe2::serialize::PyTorchStreamReader; using caffe2::serialize::PyTorchStreamReader;
using caffe2::serialize::PyTorchStreamWriter; using caffe2::serialize::PyTorchStreamWriter;
using torch::utils::SchemaInfo; using torch::utils::SchemaInfo;
static c10::SymIntNode toSymIntNode(c10::SymIntNode a, py::object b) {
return torch::is_symint_node(b) ? b.cast<c10::SymIntNode>()
: a->wrap(b.cast<int64_t>());
}
static c10::SymFloatNode toSymFloatNode(c10::SymFloatNode a, py::object b) {
if (torch::is_symfloat_node(b)) {
return b.cast<c10::SymFloatNode>();
} else if (torch::is_symint_node(b)) {
return b.cast<c10::SymIntNode>()->sym_float();
} else {
return a->wrap(b.cast<double>());
}
}
class PythonSymIntNodeImpl : public c10::SymIntNodeImpl {
public:
PythonSymIntNodeImpl(py::object pyobj) : c10::SymIntNodeImpl() {
pyobj_ = std::make_shared<c10::SafePyObject>(
pyobj.release().ptr(), getPyInterpreter());
};
virtual SymIntNode clone() override {
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr("clone")();
return c10::make_intrusive<PythonSymIntNodeImpl>(r);
}
virtual SymIntNode wrap(int64_t num) override {
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr("wrap")(num);
return c10::make_intrusive<PythonSymIntNodeImpl>(r);
}
virtual bool bool_() override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("__bool__")().is(py::handle(Py_True));
}
virtual int64_t guard_int(const char* file, int64_t line) override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("guard_int")(file, line).cast<int64_t>();
}
virtual int64_t int_() override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("__int__")().cast<int64_t>();
}
SymFloatNode sym_float() override;
virtual std::string str() override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("__str__")().cast<std::string>();
}
virtual SymIntNode dispatch_common_(
const char* fname,
const SymIntNode& other) {
auto pother = dynamic_cast<PythonSymIntNodeImpl*>(other.get());
TORCH_CHECK(pother);
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr(fname)(pother->getPyObj());
return c10::make_intrusive<PythonSymIntNodeImpl>(r);
}
virtual SymIntNode dispatch_common_(const char* fname) {
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr(fname)();
return c10::make_intrusive<PythonSymIntNodeImpl>(r);
}
virtual SymIntNode add(const SymIntNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
virtual SymIntNode sub(const SymIntNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
virtual SymIntNode mul(const SymIntNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
virtual SymFloatNode truediv(const SymIntNode& other) override;
virtual SymIntNode floordiv(const SymIntNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
virtual SymIntNode mod(const SymIntNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
virtual SymIntNode eq(const SymIntNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
virtual SymIntNode gt(const SymIntNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
virtual SymIntNode lt(const SymIntNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
virtual SymIntNode le(const SymIntNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
virtual SymIntNode ge(const SymIntNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
virtual SymIntNode min(const SymIntNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
virtual SymIntNode max(const SymIntNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
virtual SymIntNode ceil() override {
return dispatch_common_(__FUNCTION__);
}
virtual SymIntNode neg() override {
return dispatch_common_(__FUNCTION__);
}
py::handle getPyObj() {
return py::handle(pyobj_.get()->ptr(getPyInterpreter()));
}
std::shared_ptr<c10::SafePyObject> pyobj_ = nullptr;
};
class PythonSymFloatNodeImpl : public c10::SymFloatNodeImpl {
public:
PythonSymFloatNodeImpl(py::object pyobj) : c10::SymFloatNodeImpl() {
pyobj_ = std::make_shared<c10::SafePyObject>(
pyobj.release().ptr(), getPyInterpreter());
};
virtual SymFloatNode wrap(double num) override {
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr("wrap")(num);
return c10::make_intrusive<PythonSymFloatNodeImpl>(r);
}
virtual std::string str() override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("__str__")().cast<std::string>();
}
SymFloatNode dispatch_common_(const char* fname, const SymFloatNode& other) {
auto pother = dynamic_cast<PythonSymFloatNodeImpl*>(other.get());
TORCH_CHECK(pother);
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr(fname)(pother->getPyObj());
return c10::make_intrusive<PythonSymFloatNodeImpl>(r);
}
SymFloatNode add(const SymFloatNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
SymFloatNode sub(const SymFloatNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
SymFloatNode mul(const SymFloatNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
SymFloatNode truediv(const SymFloatNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
SymFloatNode pow(const SymFloatNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
SymFloatNode eq(const SymFloatNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
SymFloatNode gt(const SymFloatNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
SymFloatNode lt(const SymFloatNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
SymFloatNode le(const SymFloatNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
SymFloatNode ge(const SymFloatNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
SymIntNode ceil() override;
SymIntNode floor() override;
py::handle getPyObj() {
return py::handle(pyobj_.get()->ptr(getPyInterpreter()));
}
std::shared_ptr<c10::SafePyObject> pyobj_ = nullptr;
};
SymFloatNode PythonSymIntNodeImpl::truediv(const SymIntNode& other) {
auto pother = dynamic_cast<PythonSymIntNodeImpl*>(other.get());
TORCH_CHECK(pother);
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr("truediv")(pother->getPyObj());
return c10::make_intrusive<PythonSymFloatNodeImpl>(r);
}
SymFloatNode PythonSymIntNodeImpl::sym_float() {
py::gil_scoped_acquire acquire;
return c10::make_intrusive<PythonSymFloatNodeImpl>(
getPyObj().attr("__sym_float__")());
}
SymIntNode PythonSymFloatNodeImpl::ceil() {
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr("ceil")();
return c10::make_intrusive<PythonSymIntNodeImpl>(r);
}
SymIntNode PythonSymFloatNodeImpl::floor() {
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr("floor")();
return c10::make_intrusive<PythonSymIntNodeImpl>(r);
}
namespace { namespace {
using autograd::variable_list; using autograd::variable_list;
@ -1381,276 +1142,41 @@ void initJITBindings(PyObject* module) {
} }
}); });
auto symint_class = // NB: This isn't actually used for regular PyTorch symbolic tracing;
py::class_<c10::SymIntNodeImpl, c10::SymIntNode>(m, "SymIntNode") // XLA is what needs this
.def_static( #define SYMNODE_UNARY(n) .def(#n, [](c10::SymNode a) { return a->n(); })
"new_symint", #define SYMNODE_UNARY2(n2, n) .def(#n2, [](c10::SymNode a) { return a->n(); })
[](py::object obj) -> c10::SymIntNode { #define SYMNODE_BINARY(n) \
return c10::make_intrusive<PythonSymIntNodeImpl>(obj); .def(#n, [](c10::SymNode a, c10::SymNode b) { return a->n(b); })
}) auto symnode_class =
.def( py::class_<c10::SymNodeImpl, c10::SymNode>(m, "_SymNode")
"get_pyobj", // These DO NOT install magic methods; the SymInt/SymFloat wrapper in
[](c10::SymIntNode a) -> py::object { // Python is responsible for this
if (auto* psn = dynamic_cast<PythonSymIntNodeImpl*>(a.get())) { SYMNODE_UNARY(clone)
return py::reinterpret_borrow<py::object>(psn->getPyObj()); // Named these for consistency with inner python class, but maybe
} // should change the python side
return py::none(); SYMNODE_UNARY2(__bool__, bool_) SYMNODE_UNARY2(__int__, int_)
}) SYMNODE_UNARY2(__sym_int__, sym_int) SYMNODE_UNARY2(
.def( __sym_float__, sym_float) SYMNODE_BINARY(add) SYMNODE_BINARY(sub)
"__add__", SYMNODE_BINARY(mul) SYMNODE_BINARY(truediv) SYMNODE_BINARY(pow)
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode { SYMNODE_BINARY(floordiv) SYMNODE_BINARY(mod) SYMNODE_BINARY(
auto snb = toSymIntNode(a, b); eq) SYMNODE_BINARY(gt) SYMNODE_BINARY(lt)
return a->add(snb); SYMNODE_BINARY(le) SYMNODE_BINARY(ge) SYMNODE_BINARY(min)
}) SYMNODE_BINARY(max) SYMNODE_UNARY(ceil)
.def( SYMNODE_UNARY(floor) SYMNODE_UNARY(neg)
"__radd__", // Intentionally don't set file line, as the
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode { // Python backtrace matters more here
auto snb = toSymIntNode(a, b); .def(
return snb->add(a); "guard_int",
}) [](c10::SymNode a) {
.def( return a->guard_int(nullptr, 0);
"__sub__", })
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode { .def(
auto snb = toSymIntNode(a, b); "__str__",
return a->sub(snb); [](c10::SymNode a) { return a->str(); })
}) .def("__repr__", [](c10::SymNode a) {
.def( return a->str();
"__rsub__", });
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return snb->sub(a);
})
.def(
"__mul__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->mul(snb);
})
.def(
"__rmul__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return snb->mul(a);
})
.def(
"__truediv__",
[](c10::SymIntNode a, py::object b) -> c10::SymFloatNode {
auto snb = toSymIntNode(a, b);
return a->truediv(snb);
})
.def(
"__rtruediv__",
[](c10::SymIntNode a, py::object b) -> c10::SymFloatNode {
auto snb = toSymIntNode(a, b);
return snb->truediv(a);
})
.def(
"__floordiv__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->floordiv(snb);
})
.def(
"__rfloordiv__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return snb->floordiv(a);
})
.def(
"__mod__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->mod(snb);
})
.def(
"__rmod__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return snb->mod(a);
})
.def(
"__pow__",
[](c10::SymIntNode a, py::object b) -> py::object {
if (PyFloat_Check(b.ptr())) {
auto float_a = a->sym_float();
return py::cast(
float_a->pow(float_a->wrap(py::cast<double>(b))));
}
// TODO: integer pow
return py::reinterpret_borrow<py::object>(Py_NotImplemented);
})
// TODO: rpow
.def(
"__eq__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->eq(snb);
})
.def(
"__gt__",
[](c10::SymIntNode a, py::object b) {
auto snb = toSymIntNode(a, b);
return a->gt(snb);
})
.def(
"__lt__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->lt(snb);
})
.def(
"__le__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->le(snb);
})
.def(
"__ge__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->ge(snb);
})
.def(
"__ceil__",
[](c10::SymIntNode a) -> c10::SymIntNode { return a->ceil(); })
.def(
"__neg__",
[](c10::SymIntNode a) -> c10::SymIntNode { return a->neg(); })
.def(
"__min__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->min(snb);
})
.def(
"__max__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->max(snb);
})
.def("__bool__", [](c10::SymIntNode a) { return a->bool_(); })
.def("__int__", [](c10::SymIntNode a) { return a->int_(); })
// Intentionally don't set file line, as the Python backtrace matters
// more here
.def(
"guard_int",
[](c10::SymIntNode a) { return a->guard_int(nullptr, 0); })
.def(
"__sym_float__",
[](c10::SymIntNode a) {
// TODO: remove dynamic cast when sym_float is in base class
auto* psn = dynamic_cast<PythonSymIntNodeImpl*>(a.get());
TORCH_INTERNAL_ASSERT(psn);
return psn->sym_float();
})
.def("__str__", [](c10::SymIntNode a) { return a->str(); })
.def("__repr__", [](c10::SymIntNode a) { return a->str(); });
py::class_<c10::SymFloatNodeImpl, c10::SymFloatNode>(m, "SymFloatNode")
.def_static(
"new_symfloat",
[](py::object obj) -> c10::SymFloatNode {
return c10::make_intrusive<PythonSymFloatNodeImpl>(obj);
})
.def(
"__add__",
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
auto snb = toSymFloatNode(a, b);
return a->add(snb);
})
.def(
"__radd__",
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
auto snb = toSymFloatNode(a, b);
return snb->add(a);
})
.def(
"__sub__",
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
auto snb = toSymFloatNode(a, b);
return a->sub(snb);
})
.def(
"__mul__",
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
auto snb = toSymFloatNode(a, b);
return a->mul(snb);
})
.def(
"__rmul__",
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
auto snb = toSymFloatNode(a, b);
return snb->mul(a);
})
.def(
"__truediv__",
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
auto snb = toSymFloatNode(a, b);
return a->truediv(snb);
})
.def(
"__rtruediv__",
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
auto snb = toSymFloatNode(a, b);
return snb->truediv(a);
})
.def(
"__eq__",
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
auto snb = toSymFloatNode(a, b);
return a->eq(snb);
})
.def(
"__gt__",
[](c10::SymFloatNode a, py::object b) {
auto snb = toSymFloatNode(a, b);
return a->gt(snb);
})
.def(
"__lt__",
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
auto snb = toSymFloatNode(a, b);
return a->lt(snb);
})
.def(
"__le__",
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
auto snb = toSymFloatNode(a, b);
return a->le(snb);
})
.def(
"__ge__",
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
auto snb = toSymFloatNode(a, b);
return a->ge(snb);
})
.def(
"__pow__",
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
auto snb = toSymFloatNode(a, b);
return a->pow(snb);
})
.def(
"__rpow__",
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
auto snb = toSymFloatNode(a, b);
return snb->pow(a);
})
.def(
"__ceil__",
[](c10::SymFloatNode a) -> c10::SymIntNode { return a->ceil(); })
.def(
"__floor__",
[](c10::SymFloatNode a) -> c10::SymIntNode { return a->floor(); })
.def(
"get_pyobj",
[](c10::SymFloatNode a) -> py::object {
if (auto* psn = dynamic_cast<PythonSymFloatNodeImpl*>(a.get())) {
return py::reinterpret_borrow<py::object>(psn->getPyObj());
}
return py::none();
})
.def("__str__", [](c10::SymFloatNode a) { return a->str(); });
// NOLINTNEXTLINE(bugprone-unused-raii) // NOLINTNEXTLINE(bugprone-unused-raii)
py::class_<CompleteArgumentSpec>(m, "CompleteArgumentSpec") py::class_<CompleteArgumentSpec>(m, "CompleteArgumentSpec")

View File

@ -80,10 +80,10 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
scalar = at::Scalar(THPUtils_unpackComplexDouble(obj.ptr())); scalar = at::Scalar(THPUtils_unpackComplexDouble(obj.ptr()));
} else if (THPUtils_checkDouble(obj.ptr())) { } else if (THPUtils_checkDouble(obj.ptr())) {
scalar = at::Scalar(THPUtils_unpackDouble(obj.ptr())); scalar = at::Scalar(THPUtils_unpackDouble(obj.ptr()));
} else if (torch::is_symint_node(py::handle(obj))) { } else if (torch::is_symint(py::handle(obj))) {
save_symint = true; save_symint = true;
scalar = at::Scalar(7777777); scalar = at::Scalar(7777777);
} else if (torch::is_symfloat_node(py::handle(obj))) { } else if (torch::is_symfloat(py::handle(obj))) {
save_symint = true; save_symint = true;
scalar = at::Scalar(std::numeric_limits<double>::quiet_NaN()); scalar = at::Scalar(std::numeric_limits<double>::quiet_NaN());
} else { } else {
@ -161,12 +161,12 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
return py::cast<int64_t>(obj); return py::cast<int64_t>(obj);
} }
case TypeKind::SymIntType: case TypeKind::SymIntType:
if (torch::is_symint_node(obj.ptr())) { if (torch::is_symint(obj.ptr())) {
return py::cast<c10::SymInt>(obj); return py::cast<c10::SymInt>(obj);
} }
return py::cast<int64_t>(obj); return py::cast<int64_t>(obj);
case TypeKind::SymFloatType: case TypeKind::SymFloatType:
if (torch::is_symfloat_node(obj.ptr())) { if (torch::is_symfloat(obj.ptr())) {
return py::cast<c10::SymFloat>(obj); return py::cast<c10::SymFloat>(obj);
} }
return py::cast<double>(obj); return py::cast<double>(obj);
@ -253,7 +253,7 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
bool is_symbolic = false; bool is_symbolic = false;
for (auto it = obj.begin(); it != obj.end(); it++) { for (auto it = obj.begin(); it != obj.end(); it++) {
auto elm = *it; auto elm = *it;
if (torch::is_symint_node(elm)) { if (torch::is_symint(elm)) {
is_symbolic = true; is_symbolic = true;
break; break;
} }
@ -269,7 +269,7 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
for (auto it = obj.begin(); it != obj.end(); it++) { for (auto it = obj.begin(); it != obj.end(); it++) {
auto elm = *it; auto elm = *it;
// TODO: what about SymInt conversion to SymFloat? // TODO: what about SymInt conversion to SymFloat?
if (torch::is_symfloat_node(elm)) { if (torch::is_symfloat(elm)) {
is_symbolic = true; is_symbolic = true;
break; break;
} }
@ -442,9 +442,9 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
} else if (PyComplex_CheckExact(obj.ptr())) { } else if (PyComplex_CheckExact(obj.ptr())) {
auto c_obj = py::cast<std::complex<double>>(obj.ptr()); auto c_obj = py::cast<std::complex<double>>(obj.ptr());
return static_cast<c10::complex<double>>(c_obj); return static_cast<c10::complex<double>>(c_obj);
} else if (torch::is_symint_node(obj)) { } else if (torch::is_symint(obj)) {
return py::cast<c10::SymInt>(obj); return py::cast<c10::SymInt>(obj);
} else if (torch::is_symfloat_node(obj)) { } else if (torch::is_symfloat(obj)) {
return py::cast<c10::SymFloat>(obj); return py::cast<c10::SymFloat>(obj);
} else { } else {
throw py::cast_error( throw py::cast_error(

View File

@ -136,10 +136,10 @@ static inline NodePtr MakeSizeDiv(const Value& a, const Value& b) {
inline Value GetSymIntValue(c10::SymInt a) { inline Value GetSymIntValue(c10::SymInt a) {
return Value( return Value(
a.is_symbolic() ? dynamic_cast<torch::lazy::SymIntNodeImpl*>( a.is_symbolic()
a.toSymIntNodeImpl().get()) ? dynamic_cast<torch::lazy::SymNodeImpl*>(a.toSymNodeImpl().get())
->node_ ->node_
: MakeScalar(a.as_int_unchecked(), at::kLong), : MakeScalar(a.as_int_unchecked(), at::kLong),
0); 0);
} }

View File

@ -451,11 +451,11 @@ std::vector<Shape> compute_shape_expand(
std::vector<int64_t> target_size(_sizes.size()); std::vector<int64_t> target_size(_sizes.size());
for (const auto idx : c10::irange(_sizes.size())) { for (const auto idx : c10::irange(_sizes.size())) {
if (_sizes[idx].is_symbolic()) { if (_sizes[idx].is_symbolic()) {
c10::SymIntNode symbolicIntNode = _sizes[idx].toSymIntNodeImpl(); c10::SymNode symbolicIntNode = _sizes[idx].toSymNodeImpl();
auto* lazySymIntNode = auto* lazySymNode =
dynamic_cast<torch::lazy::SymIntNodeImpl*>(symbolicIntNode.get()); dynamic_cast<torch::lazy::SymNodeImpl*>(symbolicIntNode.get());
TORCH_INTERNAL_ASSERT(lazySymIntNode); TORCH_INTERNAL_ASSERT(lazySymNode);
auto size_node = lazySymIntNode->node_; auto size_node = lazySymNode->node_;
auto static_value = auto static_value =
std::dynamic_pointer_cast<torch::lazy::DimensionNode>(size_node) std::dynamic_pointer_cast<torch::lazy::DimensionNode>(size_node)
->getStaticValue(); ->getStaticValue();

View File

@ -4,7 +4,7 @@
#include <c10/core/ScalarType.h> #include <c10/core/ScalarType.h>
#include <c10/core/SymInt.h> #include <c10/core/SymInt.h>
#include <c10/core/SymIntArrayRef.h> #include <c10/core/SymIntArrayRef.h>
#include <c10/core/SymIntNodeImpl.h> #include <c10/core/SymNodeImpl.h>
#include <c10/macros/Export.h> #include <c10/macros/Export.h>
#include <c10/util/Optional.h> #include <c10/util/Optional.h>
#include <torch/csrc/lazy/backend/backend_data.h> #include <torch/csrc/lazy/backend/backend_data.h>

View File

@ -1,6 +1,6 @@
#pragma once #pragma once
#include <c10/core/SymIntNodeImpl.h> #include <c10/core/SymNodeImpl.h>
#include <c10/util/intrusive_ptr.h> #include <c10/util/intrusive_ptr.h>
#include <torch/csrc/lazy/backend/backend_data.h> #include <torch/csrc/lazy/backend/backend_data.h>
#include <torch/csrc/lazy/backend/backend_device.h> #include <torch/csrc/lazy/backend/backend_device.h>
@ -10,12 +10,9 @@
namespace torch { namespace torch {
namespace lazy { namespace lazy {
class TORCH_API SymIntNodeImpl : public c10::SymIntNodeImpl { class TORCH_API SymNodeImpl : public c10::SymNodeImpl {
public: public:
SymIntNodeImpl(NodePtr ptr) : node_(std::move(ptr)){}; SymNodeImpl(NodePtr ptr) : node_(std::move(ptr)){};
c10::SymIntNode add(const c10::SymIntNode& other) override {
TORCH_CHECK(false, "NYI");
}
NodePtr node_; NodePtr node_;
}; };

View File

@ -685,7 +685,7 @@ static bool is_int_list(
// NB: do NOT check that the later arguments are ints, as this is // NB: do NOT check that the later arguments are ints, as this is
// BC-breaking for FX // BC-breaking for FX
for (int i = 1; i < len; i++) { for (int i = 1; i < len; i++) {
if (torch::is_symint_node( if (torch::is_symint(
py::reinterpret_steal<py::object>(PySequence_GetItem(obj, i)))) { py::reinterpret_steal<py::object>(PySequence_GetItem(obj, i)))) {
if (failed_idx != nullptr) { if (failed_idx != nullptr) {
*failed_idx = i; *failed_idx = i;
@ -716,9 +716,9 @@ static bool is_int_list(
static bool is_int_or_symint(PyObject* obj) { static bool is_int_or_symint(PyObject* obj) {
// THPUtils_checkIndex may call __index__ or __int__ // THPUtils_checkIndex may call __index__ or __int__
// which may have side effects if obj is a symint node // which may have side effects if obj is a symint node
// so we do `is_symint_node` check first // so we do `is_symint` check first
// TODO: maybe we should be using checkLong here? // TODO: maybe we should be using checkLong here?
return torch::is_symint_node(py::handle(obj)) || THPUtils_checkIndex(obj); return torch::is_symint(py::handle(obj)) || THPUtils_checkIndex(obj);
} }
static bool is_int_or_symint_list( static bool is_int_or_symint_list(
@ -1570,13 +1570,13 @@ at::Tensor PythonArgs::tensor_slow(int i) {
// NB: we DO NOT put symbolic ints/floats into the Scalar itself, // NB: we DO NOT put symbolic ints/floats into the Scalar itself,
// because although Scalar supports SymInt/SymFloat, the subsequent // because although Scalar supports SymInt/SymFloat, the subsequent
// conversion to Tensor does not. Instead, do it out of band. // conversion to Tensor does not. Instead, do it out of band.
} else if (torch::is_symint_node(py::handle(obj))) { } else if (torch::is_symint(py::handle(obj))) {
save_symint = true; save_symint = true;
// This scalar value doesn't matter, it shouldn't ever actually // This scalar value doesn't matter, it shouldn't ever actually
// get read out. Make it a big and weird looking number to help // get read out. Make it a big and weird looking number to help
// people figure out if there's aproblem. // people figure out if there's aproblem.
scalar = at::Scalar(7777777); scalar = at::Scalar(7777777);
} else if (torch::is_symfloat_node(py::handle(obj))) { } else if (torch::is_symfloat(py::handle(obj))) {
save_symint = true; save_symint = true;
scalar = at::Scalar(std::numeric_limits<double>::quiet_NaN()); scalar = at::Scalar(std::numeric_limits<double>::quiet_NaN());
} else { } else {
@ -1633,11 +1633,11 @@ at::Scalar PythonArgs::scalar_slow(PyObject* arg) {
return at::Scalar(THPUtils_unpackComplexDouble(arg)); return at::Scalar(THPUtils_unpackComplexDouble(arg));
} }
if (torch::is_symint_node(arg)) { if (torch::is_symint(arg)) {
return at::Scalar(py::cast<c10::SymInt>(arg)); return at::Scalar(py::cast<c10::SymInt>(arg));
} }
if (torch::is_symfloat_node(arg)) { if (torch::is_symfloat(arg)) {
return at::Scalar(py::cast<c10::SymFloat>(arg)); return at::Scalar(py::cast<c10::SymFloat>(arg));
} }

View File

@ -61,6 +61,7 @@
#include <torch/csrc/utils/pybind.h> #include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_numbers.h> #include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_strings.h> #include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/python_symnode.h>
#include <torch/csrc/utils/six.h> #include <torch/csrc/utils/six.h>
#include <ATen/PythonTorchFunctionTLS.h> #include <ATen/PythonTorchFunctionTLS.h>
@ -69,7 +70,7 @@
#include <c10/util/irange.h> #include <c10/util/irange.h>
#include <c10/core/SymFloat.h> #include <c10/core/SymFloat.h>
#include <c10/core/SymIntNodeImpl.h> #include <c10/core/SymNodeImpl.h>
#include <array> #include <array>
#include <cstddef> #include <cstddef>
@ -78,30 +79,6 @@
#include <string> #include <string>
#include <vector> #include <vector>
namespace torch {
inline bool is_symint_node(py::handle obj) {
auto static tp_symn = py::type::of<c10::SymIntNodeImpl>();
if (py::isinstance(obj, tp_symn)) {
TORCH_CHECK(
!jit::tracer::isTracing(), "JIT tracing of SymInts isn't supported!");
return true;
}
return false;
}
inline bool is_symfloat_node(py::handle obj) {
auto static tp_symn = py::type::of<c10::SymFloatNodeImpl>();
if (py::isinstance(obj, tp_symn)) {
TORCH_CHECK(
!jit::tracer::isTracing(), "JIT tracing of SymFloats isn't supported!");
return true;
}
return false;
}
} // namespace torch
namespace pybind11 { namespace pybind11 {
namespace detail { namespace detail {
template <> template <>
@ -109,8 +86,10 @@ struct type_caster<c10::SymInt> {
public: public:
PYBIND11_TYPE_CASTER(c10::SymInt, _("SymInt")); PYBIND11_TYPE_CASTER(c10::SymInt, _("SymInt"));
bool load(py::handle src, bool) { bool load(py::handle src, bool) {
if (torch::is_symint_node(src)) { if (torch::is_symint(src)) {
value = src.cast<c10::SymIntNodeImpl*>()->toSymInt(); value = c10::SymInt(static_cast<c10::SymNode>(
c10::make_intrusive<torch::impl::PythonSymNodeImpl>(
src.attr("node"))));
return true; return true;
} }
@ -126,8 +105,15 @@ struct type_caster<c10::SymInt> {
c10::SymInt si, c10::SymInt si,
return_value_policy /* policy */, return_value_policy /* policy */,
handle /* parent */) { handle /* parent */) {
return si.is_symbolic() ? py::cast(si.toSymIntNodeImpl()).release() if (si.is_symbolic()) {
: py::cast(si.expect_int()).release(); // TODO: generalize this to work with C++ backed class
auto* py_node = dynamic_cast<torch::impl::PythonSymNodeImpl*>(
si.toSymNodeImpl().get());
TORCH_INTERNAL_ASSERT(py_node);
return torch::get_symint_class()(py_node->getPyObj()).release();
} else {
return py::cast(si.as_int_unchecked()).release();
}
} }
}; };
@ -136,8 +122,10 @@ struct type_caster<c10::SymFloat> {
public: public:
PYBIND11_TYPE_CASTER(c10::SymFloat, _("SymFloat")); PYBIND11_TYPE_CASTER(c10::SymFloat, _("SymFloat"));
bool load(py::handle src, bool) { bool load(py::handle src, bool) {
if (torch::is_symfloat_node(src)) { if (torch::is_symfloat(src)) {
value = src.cast<c10::SymFloatNodeImpl*>()->toSymFloat(); value = c10::SymFloat(static_cast<c10::SymNode>(
c10::make_intrusive<torch::impl::PythonSymNodeImpl>(
src.attr("node"))));
return true; return true;
} }
@ -153,8 +141,15 @@ struct type_caster<c10::SymFloat> {
c10::SymFloat si, c10::SymFloat si,
return_value_policy /* policy */, return_value_policy /* policy */,
handle /* parent */) { handle /* parent */) {
return si.is_symbolic() ? py::cast(si.toSymFloatNodeImpl()).release() if (si.is_symbolic()) {
: py::cast(si.expect_float()).release(); // TODO: generalize this to work with C++ backed class
auto* py_node = dynamic_cast<torch::impl::PythonSymNodeImpl*>(
si.toSymNodeImpl().get());
TORCH_INTERNAL_ASSERT(py_node);
return torch::get_symfloat_class()(py_node->getPyObj()).release();
} else {
return py::cast(si.as_float_unchecked()).release();
}
} }
}; };
} // namespace detail } // namespace detail
@ -167,8 +162,7 @@ inline bool THPUtils_checkScalar(PyObject* obj) {
} }
#endif #endif
return PyFloat_Check(obj) || PyLong_Check(obj) || PyComplex_Check(obj) || return PyFloat_Check(obj) || PyLong_Check(obj) || PyComplex_Check(obj) ||
torch::is_symint_node(py::handle(obj)) || torch::is_symint(py::handle(obj)) || torch::is_symfloat(py::handle(obj));
torch::is_symfloat_node(py::handle(obj));
} }
namespace torch { namespace torch {
@ -574,7 +568,7 @@ inline std::vector<int64_t> PythonArgs::intlist(int i) {
inline PyObject* toPyObject(c10::SymInt symint) { inline PyObject* toPyObject(c10::SymInt symint) {
if (symint.is_symbolic()) { if (symint.is_symbolic()) {
auto r = py::cast(symint.toSymIntNodeImpl()).release().ptr(); auto r = py::cast(symint).release().ptr();
TORCH_INTERNAL_ASSERT(r); TORCH_INTERNAL_ASSERT(r);
return r; return r;
} else { } else {
@ -609,8 +603,8 @@ inline std::vector<c10::SymInt> PythonArgs::symintlist(int i) {
size1, c10::SymInt(THPUtils_unpackIndex(args[i]))); size1, c10::SymInt(THPUtils_unpackIndex(args[i])));
} }
if (size1 > 0 && torch::is_symint_node(py::handle(args[i]))) { if (size1 > 0 && torch::is_symint(py::handle(args[i]))) {
auto si = py::handle(args[i]).cast<c10::SymIntNodeImpl*>()->toSymInt(); auto si = py::handle(args[i]).cast<c10::SymInt>();
return std::vector<c10::SymInt>(size1, si); return std::vector<c10::SymInt>(size1, si);
} }
@ -652,9 +646,8 @@ inline std::vector<c10::SymInt> PythonArgs::symintlist(int i) {
res.push_back(var.item<int64_t>()); res.push_back(var.item<int64_t>());
} else { } else {
try { try {
if (is_symint_node(py::handle(obj))) { if (is_symint(py::handle(obj))) {
res.push_back( res.push_back(py::handle(obj).cast<c10::SymInt>());
py::handle(obj).cast<c10::SymIntNodeImpl*>()->toSymInt());
} else { } else {
res.push_back(c10::SymInt(THPUtils_unpackIndex(obj))); res.push_back(c10::SymInt(THPUtils_unpackIndex(obj)));
} }

View File

@ -0,0 +1,19 @@
#include <torch/csrc/utils/python_symnode.h>
namespace torch {
py::handle get_symint_class() {
// NB: leak
static py::handle symint_class =
py::object(py::module::import("torch").attr("SymInt")).release();
return symint_class;
}
py::handle get_symfloat_class() {
// NB: leak
static py::handle symfloat_class =
py::object(py::module::import("torch").attr("SymFloat")).release();
return symfloat_class;
}
} // namespace torch

View File

@ -0,0 +1,182 @@
#pragma once
#include <c10/core/SafePyObject.h>
#include <c10/core/SymNodeImpl.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/utils/pybind.h>
namespace torch {
TORCH_PYTHON_API py::handle get_symint_class();
TORCH_PYTHON_API py::handle get_symfloat_class();
// NB: These functions must not be called too early, otherwise torch not setup.
// Alternate design is to have torch "register" the object to us
inline bool is_symint(py::handle obj) {
return py::isinstance(obj, get_symint_class());
}
inline bool is_symfloat(py::handle obj) {
return py::isinstance(obj, get_symfloat_class());
}
namespace impl {
// This c10::SymNodeImpl simply backends to a Python object that
// implements the API. The Python object is the source of truth,
// this is just an adapter so C++ calls can get to the object.
class PythonSymNodeImpl : public c10::SymNodeImpl {
public:
PythonSymNodeImpl(py::object pyobj) : c10::SymNodeImpl() {
pyobj_ = std::make_shared<c10::SafePyObject>(
pyobj.release().ptr(), getPyInterpreter());
};
c10::SymNode wrap_int(int64_t num) override {
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr("wrap_int")(num);
return c10::make_intrusive<PythonSymNodeImpl>(r);
}
c10::SymNode wrap_float(double num) override {
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr("wrap_float")(num);
return c10::make_intrusive<PythonSymNodeImpl>(r);
}
bool bool_() override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("bool_")().is(py::handle(Py_True));
}
bool is_int() override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("is_int")().is(py::handle(Py_True));
}
bool is_float() override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("is_float")().is(py::handle(Py_True));
}
int64_t guard_int(const char* file, int64_t line) override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("guard_int")(file, line).cast<int64_t>();
}
double guard_float(const char* file, int64_t line) override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("guard_float")(file, line).cast<double>();
}
int64_t int_() override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("int_")().cast<int64_t>();
}
std::string str() override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("str")().cast<std::string>();
}
c10::SymNode dispatch_common_(const char* fname, const c10::SymNode& other) {
auto pother = dynamic_cast<PythonSymNodeImpl*>(other.get());
TORCH_CHECK(pother);
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr(fname)(pother->getPyObj());
return c10::make_intrusive<PythonSymNodeImpl>(r);
}
c10::SymNode dispatch_common_(const char* fname) {
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr(fname)();
return c10::make_intrusive<PythonSymNodeImpl>(r);
}
c10::SymNode add(const c10::SymNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
c10::SymNode sub(const c10::SymNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
c10::SymNode mul(const c10::SymNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
c10::SymNode truediv(const c10::SymNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
c10::SymNode pow(const c10::SymNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
c10::SymNode floordiv(const c10::SymNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
c10::SymNode mod(const c10::SymNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
c10::SymNode eq(const c10::SymNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
c10::SymNode gt(const c10::SymNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
c10::SymNode lt(const c10::SymNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
c10::SymNode le(const c10::SymNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
c10::SymNode ge(const c10::SymNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
c10::SymNode min(const c10::SymNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
c10::SymNode max(const c10::SymNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
c10::SymNode ceil() override {
return dispatch_common_(__FUNCTION__);
}
c10::SymNode floor() override {
return dispatch_common_(__FUNCTION__);
}
c10::SymNode neg() override {
return dispatch_common_(__FUNCTION__);
}
c10::SymNode clone() override {
return dispatch_common_(__FUNCTION__);
}
c10::SymNode sym_int() override {
return dispatch_common_(__FUNCTION__);
}
c10::SymNode sym_float() override {
return dispatch_common_(__FUNCTION__);
}
py::handle getPyObj() {
return py::handle(pyobj_.get()->ptr(getPyInterpreter()));
}
std::shared_ptr<c10::SafePyObject> pyobj_ = nullptr;
};
} // namespace impl
} // namespace torch

View File

@ -21,8 +21,9 @@ import operator
from torch.utils._python_dispatch import TorchDispatchMode, _pop_mode_temporarily, _get_current_dispatch_mode from torch.utils._python_dispatch import TorchDispatchMode, _pop_mode_temporarily, _get_current_dispatch_mode
from torch._subclasses import FakeTensor from torch._subclasses import FakeTensor
from .symbolic_shapes import ShapeEnv, SymDispatchMode, PySymInt, PySymFloat from .symbolic_shapes import ShapeEnv, SymDispatchMode, SymNode
from torch.fx import Proxy from torch.fx import Proxy
from torch import SymInt, SymFloat
__all__ = ["PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter", "get_proxy", "has_proxy", "py_sym_types"] __all__ = ["PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter", "get_proxy", "has_proxy", "py_sym_types"]
aten = torch.ops.aten aten = torch.ops.aten
@ -55,27 +56,27 @@ def decompose(decomposition_table):
proxy_slot = object() proxy_slot = object()
no_default = object() no_default = object()
py_sym_types = ( py_sym_types = (SymInt, SymFloat)
PySymInt,
PySymFloat,
)
def is_sym_node(node): def is_sym_node(node):
assert hasattr(node, 'meta'), "All nodes traced with proxy_tensor should have meta" assert hasattr(node, 'meta'), "All nodes traced with proxy_tensor should have meta"
return "val" in node.meta and isinstance(node.meta['val'], py_sym_types) return "val" in node.meta and isinstance(node.meta['val'], py_sym_types)
def set_proxy_slot(obj, tracer, proxy): def set_proxy_slot(obj, tracer, proxy):
d = obj.__dict__.setdefault(proxy_slot, weakref.WeakKeyDictionary()) assert isinstance(obj, (torch.Tensor, SymNode)), type(obj)
d = obj.__dict__.setdefault(proxy_slot, weakref.WeakKeyDictionary()) # type: ignore[call-overload]
assert isinstance(d, weakref.WeakKeyDictionary) assert isinstance(d, weakref.WeakKeyDictionary)
d[tracer] = proxy d[tracer] = proxy
def has_proxy_slot(obj, tracer): def has_proxy_slot(obj, tracer):
assert isinstance(obj, (torch.Tensor, SymNode)), type(obj)
return get_proxy_slot(obj, tracer, False, lambda _: True) return get_proxy_slot(obj, tracer, False, lambda _: True)
# the default argument is what to return if the slot is not set. # the default argument is what to return if the slot is not set.
# the transform argument is handy if you need to extract a subfield from # the transform argument is handy if you need to extract a subfield from
# the successfully looked up result (but NOT the default.) # the successfully looked up result (but NOT the default.)
def get_proxy_slot(obj, tracer, default=no_default, transform=lambda x: x): def get_proxy_slot(obj, tracer, default=no_default, transform=lambda x: x):
d = obj.__dict__.get(proxy_slot) assert isinstance(obj, (torch.Tensor, SymNode)), type(obj)
d = obj.__dict__.get(proxy_slot) # type: ignore[call-overload]
if not d: if not d:
if default is no_default: if default is no_default:
raise KeyError(f"{obj} is not tracked with proxy for {tracer}") raise KeyError(f"{obj} is not tracked with proxy for {tracer}")
@ -130,10 +131,8 @@ def track_tensor(tensor, proxy, *, constant, tracer):
def try_set_proxy_slot(outer_s, proxy_callable, *args): def try_set_proxy_slot(outer_s, proxy_callable, *args):
assert callable(proxy_callable) assert callable(proxy_callable)
if isinstance(outer_s, SymInt): if isinstance(outer_s, SymInt):
inner_s = outer_s.get_pyobj() inner_s = outer_s.node
assert isinstance(inner_s, py_sym_types) set_proxy_slot(inner_s, tracer, thunkify(proxy_callable, outer_s, *args))
set_proxy_slot(inner_s, tracer, thunkify(proxy_callable, inner_s, *args))
# The basic idea is that we need to associate each tensor/SymInt # The basic idea is that we need to associate each tensor/SymInt
# with a Proxy. How do we setup this association? We just store # with a Proxy. How do we setup this association? We just store
@ -198,7 +197,7 @@ class _ProxyTensor:
def fetch_sym_proxy(tracer): def fetch_sym_proxy(tracer):
def inner(e): def inner(e):
n = e.get_pyobj() n = e.node
if n.constant is not None: if n.constant is not None:
return n.constant return n.constant
else: else:
@ -400,8 +399,8 @@ class PythonKeyTracer(Tracer):
return self.create_node('get_attr', qualname, (), {}) return self.create_node('get_attr', qualname, (), {})
elif isinstance(a, (SymInt, SymFloat)): elif isinstance(a, (SymInt, SymFloat)):
assert a.get_pyobj().constant is not None assert a.node.constant is not None
return a.get_pyobj().constant return a.node.constant
return super().create_arg(a) return super().create_arg(a)
@ -432,7 +431,7 @@ def wrap_key(f, tensors, tracer):
) )
out = pytree.tree_map_only( out = pytree.tree_map_only(
(SymInt, SymFloat), (SymInt, SymFloat),
lambda t: get_proxy_slot(t.get_pyobj(), tracer)(), lambda t: get_proxy_slot(t.node, tracer)(),
out out
) )
return out return out
@ -479,10 +478,6 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
return out return out
SymInt = torch.SymIntNode
SymFloat = torch.SymFloatNode
class ProxySymDispatchMode(SymDispatchMode): class ProxySymDispatchMode(SymDispatchMode):
def __init__(self, tracer): def __init__(self, tracer):
super().__init__() super().__init__()
@ -501,10 +496,9 @@ class ProxySymDispatchMode(SymDispatchMode):
finally: finally:
self.enable_tracing = old self.enable_tracing = old
def _compute_proxy(self, func, args, out): def _compute_proxy(self, func, args, out: Union[SymInt, SymFloat]):
n_args = tuple( n_args = tuple(
get_proxy_slot(a, self.tracer)().node if a.constant is None else a.constant get_proxy_slot(a.node, self.tracer)().node if isinstance(a, py_sym_types) else a
if isinstance(a, py_sym_types) else a
for a in args for a in args
) )
@ -520,10 +514,11 @@ class ProxySymDispatchMode(SymDispatchMode):
return func(*args, **kwargs) return func(*args, **kwargs)
# Peephole optimize multiply by one # Peephole optimize multiply by one
# NB: be careful not to trigger guards here!
if func == operator.mul: if func == operator.mul:
if isinstance(args[1], (PySymInt, PySymFloat)) and args[1].constant == 1: if isinstance(args[1], int) and args[1] == 1:
return args[0] return args[0]
elif isinstance(args[0], (PySymInt, PySymFloat)) and args[0].constant == 1: elif isinstance(args[0], int) and args[0] == 1:
return args[1] return args[1]
# For speed, we assume there are no nested data structures # For speed, we assume there are no nested data structures
@ -535,7 +530,7 @@ class ProxySymDispatchMode(SymDispatchMode):
# Delays tracing out the proxies on this op until we actually need it # Delays tracing out the proxies on this op until we actually need it
p_out_thunk = thunkify(self._compute_proxy, func=func, args=args, out=out) p_out_thunk = thunkify(self._compute_proxy, func=func, args=args, out=out)
set_proxy_slot(out, self.tracer, p_out_thunk) set_proxy_slot(out.node, self.tracer, p_out_thunk)
return out return out

View File

@ -10,6 +10,7 @@ import traceback
import collections import collections
import textwrap import textwrap
from torch._subclasses.meta_utils import MetaConverter from torch._subclasses.meta_utils import MetaConverter
from torch import SymInt, SymFloat
try: try:
import sympy # type: ignore[import] import sympy # type: ignore[import]
@ -21,8 +22,8 @@ except ImportError:
aten = torch.ops.aten # type: ignore[has-type] aten = torch.ops.aten # type: ignore[has-type]
__all__ = [ __all__ = [
"has_symbolic_sizes_strides", "create_contiguous", "PySymInt", "ShapeEnv", "has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv",
"SymDispatchMode", "PySymFloat", "sym_float", "FloorDiv" "SymDispatchMode", "sym_float", "FloorDiv", "guard_int", "wrap_node"
] ]
SYM_FUNCTION_MODE = None SYM_FUNCTION_MODE = None
@ -88,32 +89,38 @@ def _handle_sym_dispatch(func, args, kwargs):
finally: finally:
SYM_FUNCTION_MODE = mode SYM_FUNCTION_MODE = mode
def guard_int(a):
if isinstance(a, SymInt):
return a.node.guard_int("", 0) # NB: uses Python backtrace
assert isinstance(a, int)
return a
def sym_float(a): def sym_float(a):
if hasattr(a, '__sym_float__'): if isinstance(a, SymFloat):
return a.__sym_float__()
elif isinstance(a, torch._C.SymFloatNode):
return a return a
elif hasattr(a, '__sym_float__'):
return a.__sym_float__()
return float(a) return float(a)
def sym_int(a): def sym_int(a):
if hasattr(a, '__sym_int__'): if isinstance(a, SymInt):
return a.__sym_int__()
elif isinstance(a, torch._C.SymIntNode):
return a return a
elif hasattr(a, '__sym_int__'):
return a.__sym_int__()
return int(a) return int(a)
# TODO: An incomplete list # TODO: An incomplete list
# 1. Set variables to be equal when we do equality # 1. Set variables to be equal when we do equality
# 2. Specialize on 0/1 when we do subtraction # 2. Specialize on 0/1 when we do subtraction
class PySymInt(object): class SymNode:
""" """
PySymInt objects are the primary "symbolic shape" objects that flow through This is a type erased SymInt/SymFloat which we use to do actual operations.
our program. They're what sit under FakeTensor, and contains our primary End users don't touch this. Magic methods are NOT defined on this object.
implementation of symbolic shapes.
""" """
def __init__(self, expr, shape_env, constant=None): def __init__(self, expr, shape_env, pytype, constant=None):
self._expr = expr self._expr = expr
self.shape_env = shape_env self.shape_env = shape_env
self.pytype = pytype
self.constant = constant self.constant = constant
@property @property
@ -121,23 +128,49 @@ class PySymInt(object):
self._update_expr() self._update_expr()
return self._expr return self._expr
def wrap(self, num):
return PySymInt(sympy.Integer(num), self.shape_env, constant=num)
def clone(self):
return PySymInt(self.expr, self.shape_env, constant=self.constant)
def _update_expr(self): def _update_expr(self):
self._expr = self.shape_env.replace(self._expr) self._expr = self.shape_env.replace(self._expr)
def __str__(self): def to_node(self, num):
if isinstance(num, (SymInt, SymFloat)):
return num.node
elif isinstance(num, int):
return self.wrap_int(num)
elif isinstance(num, float):
return self.wrap_float(num)
else:
# NotImplementedError is important so that Python tries the
# other magic method
raise NotImplementedError(type(num))
def is_int(self):
return self.pytype is int
def is_float(self):
return self.pytype is float
def wrap_int(self, num):
assert isinstance(num, int)
return SymNode(sympy.Integer(num), self.shape_env, int, constant=num)
def wrap_float(self, num):
assert isinstance(num, float)
return SymNode(sympy.Integer(num), self.shape_env, float, constant=num)
def clone(self):
return SymNode(self.expr, self.shape_env, self.pytype, constant=self.constant)
def str(self):
return f"{self.expr}" return f"{self.expr}"
def __str__(self):
return self.str()
def __repr__(self): def __repr__(self):
return f"{self.expr}" return self.str()
# Today we error on calling int on a symbolic shape, as this is a very accessible footgun. # Today we error on calling int on a symbolic shape, as this is a very accessible footgun.
def __int__(self): def int_(self):
raise RuntimeError("Trying to extract a concrete int out of a symbolic int") raise RuntimeError("Trying to extract a concrete int out of a symbolic int")
# You can manually trigger a guard with this function # You can manually trigger a guard with this function
@ -146,28 +179,35 @@ class PySymInt(object):
# guard occurred # guard occurred
return int(self.shape_env.evaluate_expr(self.expr)) return int(self.shape_env.evaluate_expr(self.expr))
def __sym_float__(self): def guard_float(self, file, line):
# TODO: use the file/line for some useful diagnostic on why a
# guard occurred
return float(self.shape_env.evaluate_expr(self.expr))
def sym_float(self):
if SYM_FUNCTION_MODE: if SYM_FUNCTION_MODE:
return _handle_sym_dispatch(sym_float, (self,), {}) r = _handle_sym_dispatch(sym_float, (wrap_node(self),), {})
assert isinstance(r, (SymInt, SymFloat)), type(r)
return r.node
# TODO: consider constant prop here # TODO: consider constant prop here
# TODO: wrapping the expr with sympy.Float doesn't seem to work, why # TODO: wrapping the expr with sympy.Float doesn't seem to work, why
# not? # not?
return PySymFloat(self.expr, self.shape_env) return SymNode(self.expr, self.shape_env, float)
def __bool__(self): def sym_int(self):
raise NotImplementedError("sym_int NYI")
"""
if SYM_FUNCTION_MODE:
return _handle_sym_dispatch(sym_int, (self,), {})
# TODO: consider constant prop here
# XXX: need to cast float to int in sympy; math.floor is wrong
# because negatives round to zero
return SymNode(self.expr, self.shape_env, int)
"""
def bool_(self):
return bool(self.shape_env.evaluate_expr(self.shape_env.replace(self.expr))) return bool(self.shape_env.evaluate_expr(self.shape_env.replace(self.expr)))
class PySymFloat:
def __init__(self, expr, shape_env, constant=None):
self.expr = expr
self.shape_env = shape_env
self.constant = constant
def wrap(self, num):
return PySymFloat(sympy.Float(num), self.shape_env, constant=num)
def __str__(self):
return f"{self.expr}"
if HAS_SYMPY: if HAS_SYMPY:
class FloorDiv(sympy.Function): class FloorDiv(sympy.Function):
@ -238,32 +278,45 @@ unary_magic_methods = {
float_magic_methods = {"add", "sub", "mul", "truediv", "ceil", "floor", "eq", "gt", "lt", "le", "ge", "pow"} float_magic_methods = {"add", "sub", "mul", "truediv", "ceil", "floor", "eq", "gt", "lt", "le", "ge", "pow"}
def _make_magic(method, func, py_type): def wrap_node(x):
if not isinstance(x, SymNode):
return x
if x.constant is not None:
return x.constant
if x.pytype is int:
return SymInt(x)
elif x.pytype is float:
return SymFloat(x)
else:
raise AssertionError(f"unrecognized return type {x.pytype}")
def _make_node_magic(method, func):
func = lru_cache(256)(func) func = lru_cache(256)(func)
def magic_impl(self, other): def binary_magic_impl(self, other):
if method in ["min", "max"]: if method in ["min", "max"]:
op = getattr(builtins, method) op = getattr(builtins, method)
else: else:
op = getattr(operator, method) op = getattr(operator, method)
if SYM_FUNCTION_MODE: if SYM_FUNCTION_MODE:
return _handle_sym_dispatch(op, (self, other), {}) r = _handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {})
if isinstance(other, py_type): assert isinstance(r, (SymInt, SymFloat)), type(r)
other_expr = other.expr return r.node
else: assert isinstance(other, SymNode)
assert isinstance(other, sympy.Expr) other_expr = other.expr
other_expr = other
# TODO: consider constant prop here # TODO: consider constant prop here
expr = self.shape_env.replace(self.expr) expr = self.shape_env.replace(self.expr)
other_expr = self.shape_env.replace(other_expr) other_expr = self.shape_env.replace(other_expr)
out = func(expr, other_expr) out = func(expr, other_expr)
out = sympy.expand(out) out = sympy.expand(out)
if method in ["truediv"]: if method in ["truediv"]:
return PySymFloat(out, self.shape_env) pytype = float
else: else:
# TODO: relational operators actually technically return a pytype = self.pytype
# PySymBool, this is a type error
return py_type(out, self.shape_env) # TODO: relational operators actually technically return a
# PySymBool, this is a type error
return SymNode(out, self.shape_env, pytype)
def unary_magic_impl(self): def unary_magic_impl(self):
if SYM_FUNCTION_MODE: if SYM_FUNCTION_MODE:
@ -271,33 +324,55 @@ def _make_magic(method, func, py_type):
op = getattr(math, method) op = getattr(math, method)
else: else:
op = getattr(operator, method) op = getattr(operator, method)
return _handle_sym_dispatch(op, (self,), {}) r = _handle_sym_dispatch(op, (wrap_node(self),), {})
assert isinstance(r, (SymInt, SymFloat)), type(r)
return r.node
# TODO: consider constant prop here # TODO: consider constant prop here
expr = self.shape_env.replace(self.expr) expr = self.shape_env.replace(self.expr)
out = func(expr) out = func(expr)
out = sympy.expand(out) out = sympy.expand(out)
if method in ["ceil", "floor"]: if method in ["ceil", "floor"]:
return PySymInt(out, self.shape_env) pytype = int
else: else:
return py_type(out, self.shape_env) pytype = self.pytype
return SymNode(out, self.shape_env, pytype)
# this should be wrapped transparently into torch.SymIntNode
if method in unary_magic_methods: if method in unary_magic_methods:
setattr(py_type, method, unary_magic_impl) setattr(SymNode, method, unary_magic_impl)
setattr(py_type, f"__{method}__", unary_magic_impl)
else: else:
setattr(py_type, method, magic_impl) setattr(SymNode, method, binary_magic_impl)
setattr(py_type, f"__{method}__", magic_impl)
if method in reflectable_magic_methods:
setattr(py_type, f"__r{method}__", magic_impl)
for method, func in magic_methods.items(): for method, func in magic_methods.items():
_make_magic(method, func, PySymInt) _make_node_magic(method, func)
def _make_user_magic(method, user_type):
# User magic takes care of wrapping the other operand into a node,
# so that our internal logic can assume everything is nodes
def unary_magic_impl(self):
return wrap_node(getattr(self.node, method)())
def binary_magic_impl(self, other):
return wrap_node(getattr(self.node, method)(self.node.to_node(other)))
def rbinary_magic_impl(self, other):
return wrap_node(getattr(self.node.to_node(other), method)(self.node))
if method in unary_magic_methods:
setattr(user_type, f"__{method}__", unary_magic_impl)
else:
setattr(user_type, f"__{method}__", binary_magic_impl)
if method in reflectable_magic_methods:
setattr(user_type, f"__r{method}__", rbinary_magic_impl)
for method, func in magic_methods.items():
_make_user_magic(method, SymInt)
for method, func in magic_methods.items(): for method, func in magic_methods.items():
if method not in float_magic_methods: if method not in float_magic_methods:
continue continue
_make_magic(method, func, PySymFloat) _make_user_magic(method, SymFloat)
del method del method
del func del func
@ -390,9 +465,7 @@ class ShapeEnv(object):
return [self.create_symintnode(i) for i in size], [self.create_symintnode(i) for i in stride] # type: ignore[arg-type] return [self.create_symintnode(i) for i in size], [self.create_symintnode(i) for i in stride] # type: ignore[arg-type]
def create_symintnode(self, expr: "sympy.Expr"): def create_symintnode(self, expr: "sympy.Expr"):
py_sym_int = PySymInt(expr, self) return SymInt(SymNode(expr, self, int))
cpp_sym_int = torch.SymIntNode.new_symint(py_sym_int) # type: ignore[attr-defined]
return cpp_sym_int
def create_symbol(self, val: int) -> "sympy.Expr": def create_symbol(self, val: int) -> "sympy.Expr":
if not HAS_SYMPY: if not HAS_SYMPY:

View File

@ -498,7 +498,7 @@ class CodeGen(object):
if isinstance(meta_val, FakeTensor): if isinstance(meta_val, FakeTensor):
maybe_type_annotation = f': {dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}' maybe_type_annotation = f': {dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}'
elif isinstance(meta_val, py_sym_types): elif isinstance(meta_val, py_sym_types):
maybe_type_annotation = f': Sym({meta_val.expr})' maybe_type_annotation = f': Sym({meta_val})'
elif isinstance(meta_val, TensorMetadata): elif isinstance(meta_val, TensorMetadata):
maybe_type_annotation = f': {dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}' maybe_type_annotation = f': {dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}'