Unify SymIntNode and SymFloatNode into SymNode (#87817)

This refactor was prompted by challenges handling mixed int/float
operations in C++.  A previous version of this patch
added overloads for each permutation of int/float and was unwieldy
https://github.com/pytorch/pytorch/pull/87722/  This PR takes a different
approach.

The general outline of the patch is to combine the C++ types SymIntNode
and SymFloatNode into a single type, SymNode.  This is type erased; we
no longer know statically at C++ if we have an int/float and have to test
it with the is_int()/is_float() virtual methods.  This has a number of
knock on effects.

- We no longer have C++ classes to bind to Python.  Instead, we take an
  entirely new approach to our Python API, where we have a SymInt/SymFloat
  class defined entirely in Python, which hold a SymNode (which corresponds
  to the C++ SymNode).  However, SymNode is not pybind11-bound; instead,
  it lives as-is in Python, and is wrapped into C++ SymNode using PythonSymNode
  when it goes into C++.  This implies a userland rename.

  In principle, it is also possible for the canonical implementation of SymNode
  to be written in C++, and then bound to Python with pybind11 (we have
  this code, although it is commented out.)  However, I did not implement
  this as we currently have no C++ implementations of SymNode.

  Because we do return SymInt/SymFloat from C++ bindings, the C++ binding
  code needs to know how to find these classes.  Currently, this is done
  just by manually importing torch and getting the attributes.

- Because SymInt/SymFloat are easy Python wrappers, __sym_dispatch__ now
  takes SymInt/SymFloat, rather than SymNode, bringing it in line with how
  __torch_dispatch__ works.

Some miscellaneous improvements:

- SymInt now has a constructor that takes SymNode.  Note that this
  constructor is ambiguous if you pass in a subclass of SymNode,
  so an explicit downcast is necessary.  This means toSymFloat/toSymInt
  are no more.  This is a mild optimization as it means rvalue reference
  works automatically.

- We uniformly use the caster for c10::SymInt/SymFloat, rather than
  going the long way via the SymIntNode/SymFloatNode.

- Removed some unnecessary toSymInt/toSymFloat calls in normalize_*
  functions, pretty sure this doesn't do anything.

- guard_int is now a free function, since to guard on an int you cannot
  assume the method exists.  A function can handle both int and SymInt
  inputs.

- We clean up the magic method definition code for SymInt/SymFloat/SymNode.
  ONLY the user classes (SymInt/SymFloat) get magic methods; SymNode gets
  plain methods; this is to help avoid confusion between the two types.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

cc @jansel @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87817
Approved by: https://github.com/albanD, https://github.com/anjali411
This commit is contained in:
Edward Z. Yang
2022-10-27 13:49:11 -07:00
committed by PyTorch MergeBot
parent 2205f56f46
commit 1ff52225f1
54 changed files with 732 additions and 1439 deletions

View File

@ -1 +1 @@
095ee628212f0235ad0d6908bdd514123639fc86
1e9b8bdc75114ac6c16305c970be37a1cd2fdb1c

View File

@ -439,7 +439,7 @@ command = [
"""--error-description=\
This line has an isinstance call that directly refers to \
int or float. This is error-prone because you may also \
have wanted to allow SymIntNode or SymFloatNode in your test. \
have wanted to allow SymInt or SymFloat in your test. \
To suppress this lint, use an appropriate type alias defined \
in torch._prims_common; use IntLike/FloatLike when you would accept \
both regular and symbolic numbers, Dim for ints representing \

View File

@ -95,7 +95,7 @@ c10::SymInt get_nbytes(const Tensor& value) {
if (value.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {
// Today, the two implementations of SymInt are in Python (proxy tensor),
// and lazy tensor (LTC/XLA).
// LTC hasn't implemented SymInt support yet though (torch::lazy::SymIntNodeImpl).
// LTC hasn't implemented SymInt support yet though
// Once it does, we should remove this check.
if (value.key_set().has(c10::DispatchKey::Python)) {
return value.storage().sym_nbytes();

View File

@ -562,7 +562,7 @@ public:
IValue(c10::SymInt i) {
if (i.is_symbolic()) {
tag = Tag::SymInt;
payload.u.as_intrusive_ptr = i.toSymIntNodeImpl().release();
payload.u.as_intrusive_ptr = i.toSymNodeImpl().release();
} else {
tag = Tag::Int;
payload.u.as_int = i.as_int_unchecked();
@ -578,7 +578,7 @@ public:
IValue(c10::SymFloat i) {
if (i.is_symbolic()) {
tag = Tag::SymFloat;
payload.u.as_intrusive_ptr = i.toSymFloatNodeImpl().release();
payload.u.as_intrusive_ptr = i.toSymNodeImpl().release();
} else {
tag = Tag::Double;
payload.u.as_double = i.as_float_unchecked();
@ -812,10 +812,10 @@ public:
// for both SymFloat and double
if (s.isSymInt()) {
tag = Tag::SymInt;
payload.u.as_intrusive_ptr = s.toSymInt().toSymIntNodeImpl().release();
payload.u.as_intrusive_ptr = s.toSymInt().toSymNodeImpl().release();
} else if (s.isSymFloat()) {
tag = Tag::SymFloat;
payload.u.as_intrusive_ptr = s.toSymFloat().toSymFloatNodeImpl().release();
payload.u.as_intrusive_ptr = s.toSymFloat().toSymNodeImpl().release();
} else if (s.isFloatingPoint()) {
tag = Tag::Double;
payload.u.as_double = s.toDouble();

View File

@ -219,7 +219,7 @@ inline at::Generator IValue::toGenerator() const& {
inline c10::SymInt IValue::toSymInt() const {
AT_ASSERT(isSymInt() || isInt(), "Expected SymInt or int but got ", tagKind());
if (isSymInt()) {
return c10::SymInt::toSymInt(toIntrusivePtr<c10::SymIntNodeImpl>());
return c10::SymInt(toIntrusivePtr<c10::SymNodeImpl>());
} else {
return c10::SymInt(payload.u.as_int);
}
@ -228,7 +228,7 @@ inline c10::SymInt IValue::toSymInt() const {
inline c10::SymFloat IValue::toSymFloat() const {
AT_ASSERT(isSymFloat() || isDouble(), "Expected SymFloat or double but got ", tagKind());
if (isSymFloat()) {
return c10::SymFloat::toSymFloat(toIntrusivePtr<c10::SymFloatNodeImpl>());
return c10::SymFloat(toIntrusivePtr<c10::SymNodeImpl>());
} else {
return c10::SymFloat(payload.u.as_double);
}

View File

@ -1310,7 +1310,6 @@ struct TORCH_API SymIntType : public Type {
return "SymInt";
}
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
// TODO: will become a Union[SymIntNodeImpl|int] in the near future
return "int";
}
static const TypeKind Kind = TypeKind::SymIntType;

View File

@ -194,34 +194,3 @@ TEST(TestScalar, TestFormatting) {
ASSERT_EQ("(2,3.1)", format(Scalar(c10::complex<float>(2.0, 3.1))));
ASSERT_EQ("4", format(Scalar(Scalar(4).toSymInt())));
}
TEST(TestSymInt, Basic) {
Scalar foo;
auto a_impl = c10::make_intrusive<c10::SymIntNodeImpl>();
foo = Scalar(a_impl->toSymInt());
ASSERT_EQ(a_impl.use_count(), 2);
Scalar bar{foo};
ASSERT_EQ(a_impl.use_count(), 3);
auto baz = bar;
ASSERT_EQ(a_impl.use_count(), 4);
auto foo2 = std::move(bar);
ASSERT_EQ(a_impl.use_count(), 4);
ASSERT_TRUE(foo2.isSymInt());
// NOLINTNEXTLINE(bugprone-use-after-move,clang-analyzer-cplusplus.Move)
ASSERT_TRUE(bar.isIntegral(false));
foo2 = SymInt(4);
ASSERT_FALSE(foo2.isSymInt());
ASSERT_EQ(foo2.toSymInt().expect_int(), 4);
// NOLINTNEXTLINE(clang-diagnostic-self-assign-overloaded)
foo2 = foo2;
ASSERT_FALSE(foo2.isSymInt());
ASSERT_EQ(foo2.toSymInt().expect_int(), 4);
ASSERT_EQ(a_impl.use_count(), 3);
ASSERT_THROW(foo.to<double>(), c10::Error);
Scalar int_s = 3;
TORCH_CHECK(int_s.toSymInt().expect_int(), 3);
}

View File

@ -958,6 +958,7 @@ libtorch_python_core_sources = [
"torch/csrc/utils/object_ptr.cpp",
"torch/csrc/utils/python_arg_parser.cpp",
"torch/csrc/utils/python_dispatch.cpp",
"torch/csrc/utils/python_symnode.cpp",
"torch/csrc/utils/structseq.cpp",
"torch/csrc/utils/tensor_apply.cpp",
"torch/csrc/utils/tensor_dtypes.cpp",

View File

@ -92,8 +92,8 @@ class C10_API Scalar {
SymInt toSymInt() const {
if (Tag::HAS_si == tag) {
return c10::SymInt::toSymInt(intrusive_ptr<SymIntNodeImpl>::reclaim_copy(
static_cast<SymIntNodeImpl*>(v.p)));
return c10::SymInt(intrusive_ptr<SymNodeImpl>::reclaim_copy(
static_cast<SymNodeImpl*>(v.p)));
} else {
return toLong();
}
@ -101,9 +101,8 @@ class C10_API Scalar {
SymFloat toSymFloat() const {
if (Tag::HAS_sd == tag) {
return c10::SymFloat::toSymFloat(
intrusive_ptr<SymFloatNodeImpl>::reclaim_copy(
static_cast<SymFloatNodeImpl*>(v.p)));
return c10::SymFloat(intrusive_ptr<SymNodeImpl>::reclaim_copy(
static_cast<SymNodeImpl*>(v.p)));
} else {
return toDouble();
}

View File

@ -1,32 +1,27 @@
#include <c10/core/SymFloat.h>
#include <c10/core/SymFloatNodeImpl.h>
#include <c10/core/SymNodeImpl.h>
#include <array>
namespace c10 {
SymFloatNode SymFloat::toSymFloatNodeImpl() const {
SymNode SymFloat::toSymNodeImpl() const {
TORCH_CHECK(is_symbolic());
return SymFloatNode::reclaim_copy(toSymFloatNodeImplUnowned());
return SymNode::reclaim_copy(toSymNodeImplUnowned());
}
static std::array<SymFloatNode, 2> normalize_symfloats(
SymFloat a_,
SymFloat b_) {
SymFloatNode a, b;
static std::array<SymNode, 2> normalize_symfloats(SymFloat a_, SymFloat b_) {
SymNode a, b;
if (a_.is_symbolic())
a = a_.toSymFloatNodeImpl();
a = a_.toSymNodeImpl();
if (b_.is_symbolic())
b = b_.toSymFloatNodeImpl();
b = b_.toSymNodeImpl();
SymFloatNodeImpl* common = a ? a.get() : b.get();
// TODO: technically we need to check that the classes match
SymNodeImpl* common = a ? a.get() : b.get();
if (!a) {
a = common->wrap(a_.as_float_unchecked());
a_.toSymFloat(a); //
a = common->wrap_float(a_.as_float_unchecked());
}
if (!b) {
b = common->wrap(b_.as_float_unchecked());
b_.toSymFloat(b);
b = common->wrap_float(b_.as_float_unchecked());
}
return {a, b};
}
@ -36,7 +31,7 @@ SymFloat SymFloat::operator+(SymFloat sci) const {
return SymFloat(data_ + sci.data_);
}
auto res = normalize_symfloats(*this, sci);
return SymFloat::toSymFloat(res[0]->add(res[1]));
return SymFloat(res[0]->add(res[1]));
}
SymFloat SymFloat::operator-(SymFloat sci) const {
@ -44,7 +39,7 @@ SymFloat SymFloat::operator-(SymFloat sci) const {
return SymFloat(data_ - sci.data_);
}
auto res = normalize_symfloats(*this, sci);
return SymFloat::toSymFloat(res[0]->sub(res[1]));
return SymFloat(res[0]->sub(res[1]));
}
SymFloat SymFloat::operator*(SymFloat sci) const {
@ -52,7 +47,7 @@ SymFloat SymFloat::operator*(SymFloat sci) const {
return SymFloat(data_ * sci.data_);
}
auto res = normalize_symfloats(*this, sci);
return SymFloat::toSymFloat(res[0]->mul(res[1]));
return SymFloat(res[0]->mul(res[1]));
}
SymFloat SymFloat::operator/(SymFloat sci) const {
@ -60,16 +55,12 @@ SymFloat SymFloat::operator/(SymFloat sci) const {
return SymFloat(data_ / sci.data_);
}
auto res = normalize_symfloats(*this, sci);
return SymFloat::toSymFloat(res[0]->truediv(res[1]));
}
c10::SymFloat SymFloat::toSymFloat(SymFloatNode sin_sp) {
return c10::SymFloat(std::move(sin_sp));
return SymFloat(res[0]->truediv(res[1]));
}
std::ostream& operator<<(std::ostream& os, SymFloat s) {
if (s.is_symbolic()) {
os << s.toSymFloatNodeImpl()->str();
os << s.toSymNodeImpl()->str();
} else {
os << s.as_float_unchecked();
}

View File

@ -1,6 +1,6 @@
#pragma once
#include <c10/core/SymFloatNodeImpl.h>
#include <c10/core/SymNodeImpl.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/intrusive_ptr.h>
@ -14,20 +14,21 @@ namespace c10 {
class C10_API SymFloat {
public:
/*implicit*/ SymFloat(double d) : data_(d){};
SymFloat(SymFloatNode ptr)
: data_(std::numeric_limits<double>::quiet_NaN()), ptr_(std::move(ptr)){};
SymFloat(SymNode ptr)
: data_(std::numeric_limits<double>::quiet_NaN()), ptr_(std::move(ptr)) {
TORCH_CHECK(ptr_->is_float());
};
SymFloat() : data_(0.0) {}
SymFloatNodeImpl* toSymFloatNodeImplUnowned() const {
SymNodeImpl* toSymNodeImplUnowned() const {
return ptr_.get();
}
SymFloatNodeImpl* release() && {
SymNodeImpl* release() && {
return std::move(ptr_).release();
}
SymFloatNode toSymFloatNodeImpl() const;
static c10::SymFloat toSymFloat(SymFloatNode sin);
SymNode toSymNodeImpl() const;
double expect_float() const {
TORCH_CHECK(!is_symbolic());
@ -53,7 +54,7 @@ class C10_API SymFloat {
private:
// TODO: optimize to union
double data_;
SymFloatNode ptr_;
SymNode ptr_;
};
C10_API std::ostream& operator<<(std::ostream& os, SymFloat s);

View File

@ -1,20 +0,0 @@
#include <c10/core/SymFloat.h>
#include <c10/core/SymFloatNodeImpl.h>
#include <c10/core/SymIntNodeImpl.h>
namespace c10 {
c10::SymFloat SymFloatNodeImpl::toSymFloat() {
auto sit_sp = SymFloatNode::reclaim_copy(this);
return SymFloat::toSymFloat(sit_sp);
}
c10::SymIntNode SymFloatNodeImpl::ceil() {
TORCH_CHECK(false, "NYI");
}
c10::SymIntNode SymFloatNodeImpl::floor() {
TORCH_CHECK(false, "NYI");
}
} // namespace c10

View File

@ -1,76 +0,0 @@
#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/intrusive_ptr.h>
#include <memory>
#include <mutex>
#include <vector>
namespace c10 {
class SymIntNodeImpl;
using SymIntNode = c10::intrusive_ptr<SymIntNodeImpl>;
class SymFloat;
class SymFloatNodeImpl;
using SymFloatNode = c10::intrusive_ptr<SymFloatNodeImpl>;
class C10_API SymFloatNodeImpl : public c10::intrusive_ptr_target {
public:
c10::SymFloat toSymFloat();
virtual ~SymFloatNodeImpl(){};
template <typename T>
c10::intrusive_ptr<T> dyn_cast() const {
return c10::intrusive_ptr<T>::reclaim_copy(dynamic_cast<T*>(this));
}
virtual SymFloatNode wrap(double num) {
TORCH_CHECK(false, "NYI");
};
virtual SymFloatNode add(const SymFloatNode& other) {
TORCH_CHECK(false, "NYI");
}
virtual SymFloatNode sub(const SymFloatNode& other) {
TORCH_CHECK(false, "NYI");
}
virtual SymFloatNode mul(const SymFloatNode& other) {
TORCH_CHECK(false, "NYI");
}
virtual SymFloatNode truediv(const SymFloatNode& other) {
TORCH_CHECK(false, "NYI");
}
virtual SymFloatNode pow(const SymFloatNode& other) {
TORCH_CHECK(false, "NYI");
}
virtual SymFloatNode eq(const SymFloatNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymFloatNode ne(const SymFloatNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymFloatNode gt(const SymFloatNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymFloatNode lt(const SymFloatNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymFloatNode le(const SymFloatNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymFloatNode ge(const SymFloatNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymIntNode ceil();
virtual SymIntNode floor();
virtual std::string str() {
TORCH_CHECK(false, "NYI");
};
std::ostream& operator<<(std::ostream& os) {
os << str();
return os;
};
};
} // namespace c10

View File

@ -1,47 +1,46 @@
#include <c10/core/SymFloat.h>
#include <c10/core/SymInt.h>
#include <c10/core/SymIntNodeImpl.h>
#include <c10/core/SymNodeImpl.h>
#include <array>
namespace c10 {
static std::array<SymIntNode, 2> normalize_symints(SymInt a_, SymInt b_) {
SymIntNode a, b;
static std::array<SymNode, 2> normalize_symints(SymInt a_, SymInt b_) {
SymNode a, b;
if (a_.is_symbolic())
a = a_.toSymIntNodeImpl();
a = a_.toSymNodeImpl();
if (b_.is_symbolic())
b = b_.toSymIntNodeImpl();
b = b_.toSymNodeImpl();
SymIntNodeImpl* common = a ? a.get() : b.get();
SymNodeImpl* common = a ? a.get() : b.get();
// TODO: technically we need to check that the classes match
if (!a) {
a = common->wrap(a_.as_int_unchecked());
a_.toSymInt(a); //
a = common->wrap_int(a_.as_int_unchecked());
}
if (!b) {
b = common->wrap(b_.as_int_unchecked());
b_.toSymInt(b);
b = common->wrap_int(b_.as_int_unchecked());
}
return {a, b};
}
SymIntNode SymInt::toSymIntNodeImpl() const {
SymNode SymInt::toSymNodeImpl() const {
TORCH_CHECK(is_symbolic());
return SymIntNode::reclaim_copy(toSymIntNodeImplUnowned());
return SymNode::reclaim_copy(toSymNodeImplUnowned());
}
c10::SymInt SymInt::toSymInt(SymIntNode sin_sp) {
SymInt::SymInt(SymNode sin_sp) {
TORCH_CHECK(sin_sp->is_int());
auto ptr = static_cast<uint64_t>(
reinterpret_cast<uintptr_t>(static_cast<void*>(sin_sp.release())));
auto rep = (ptr & ~MASK) | IS_SYM;
return c10::SymInt(UNCHECKED, static_cast<int64_t>(rep));
data_ = static_cast<int64_t>(rep);
}
int64_t SymInt::guard_int(const char* file, int64_t line) const {
if (!is_symbolic()) {
return data_;
}
SymIntNode a = toSymIntNodeImpl();
SymNode a = toSymNodeImpl();
return a->guard_int(file, line);
}
@ -49,7 +48,7 @@ SymInt::operator SymFloat() const {
if (!is_symbolic()) {
return SymFloat(double(data_));
}
return SymFloat::toSymFloat(toSymIntNodeImpl()->sym_float());
return SymFloat(toSymNodeImpl()->sym_float());
}
SymInt SymInt::operator+(SymInt sci) const {
@ -57,7 +56,7 @@ SymInt SymInt::operator+(SymInt sci) const {
return SymInt(data_ + sci.data_);
}
auto res = normalize_symints(*this, sci);
return SymInt::toSymInt(res[0]->add(res[1]));
return SymInt(res[0]->add(res[1]));
}
SymInt SymInt::operator-(SymInt sci) const {
@ -65,7 +64,7 @@ SymInt SymInt::operator-(SymInt sci) const {
return SymInt(data_ - sci.data_);
}
auto res = normalize_symints(*this, sci);
return SymInt::toSymInt(res[0]->sub(res[1]));
return SymInt(res[0]->sub(res[1]));
}
SymInt SymInt::operator*(SymInt sci) const {
@ -73,7 +72,7 @@ SymInt SymInt::operator*(SymInt sci) const {
return SymInt(data_ * sci.data_);
}
auto res = normalize_symints(*this, sci);
return SymInt::toSymInt(res[0]->mul(res[1]));
return SymInt(res[0]->mul(res[1]));
}
SymInt SymInt::operator/(SymInt sci) const {
@ -81,7 +80,7 @@ SymInt SymInt::operator/(SymInt sci) const {
return SymInt(data_ / sci.data_);
}
auto res = normalize_symints(*this, sci);
return SymInt::toSymInt(res[0]->floordiv(res[1]));
return SymInt(res[0]->floordiv(res[1]));
}
SymInt SymInt::operator%(SymInt sci) const {
@ -89,7 +88,7 @@ SymInt SymInt::operator%(SymInt sci) const {
return SymInt(data_ % sci.data_);
}
auto res = normalize_symints(*this, sci);
return SymInt::toSymInt(res[0]->mod(res[1]));
return SymInt(res[0]->mod(res[1]));
}
bool SymInt::operator==(SymInt sci) const {
@ -141,14 +140,14 @@ SymInt SymInt::min(SymInt sci) const {
return std::min(data_, sci.data_);
}
auto res = normalize_symints(*this, sci);
return SymInt::toSymInt(res[0]->min(res[1]));
return SymInt(res[0]->min(res[1]));
}
SymInt SymInt::max(SymInt sci) const {
if (!is_symbolic() && !sci.is_symbolic()) {
return std::max(data_, sci.data_);
}
auto res = normalize_symints(*this, sci);
return SymInt::toSymInt(res[0]->max(res[1]));
return SymInt(res[0]->max(res[1]));
}
void SymInt::operator*=(SymInt sci) {
@ -193,7 +192,7 @@ SymInt SymInt::operator*(int64_t sci) const {
std::ostream& operator<<(std::ostream& os, SymInt s) {
if (s.is_symbolic()) {
os << s.toSymIntNodeImpl()->str();
os << s.toSymNodeImpl()->str();
} else {
os << s.as_int_unchecked();
}
@ -202,7 +201,7 @@ std::ostream& operator<<(std::ostream& os, SymInt s) {
SymInt operator-(SymInt s) {
if (s.is_symbolic()) {
return SymInt::toSymInt(s.toSymIntNodeImpl()->neg());
return SymInt(s.toSymNodeImpl()->neg());
} else {
return SymInt(-s.as_int_unchecked());
}

View File

@ -1,6 +1,6 @@
#pragma once
#include <c10/core/SymIntNodeImpl.h>
#include <c10/core/SymNodeImpl.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/intrusive_ptr.h>
@ -12,24 +12,19 @@ namespace c10 {
class SymFloat;
// `SymInt` is a C++ wrapper class around int64_t data_ which and is used to
// represent concrete dimension values.
// SymInt represents either a regular int64_t, or a symbolic integer
// (represented in a type erased way as SymNode). The intention is for SymInt
// to represent symbolic sizes that arise when doing shape computation in
// operator kernels. This allows for tracing through programs without baking in
// concrete sizes into kernel calls.
//
// `SymInt` is also a data type in Pytorch that can be used in function schemas
// to enable tracing.
// SymInt has an API equivalent to int64_t. In particular, it is a value type.
// Internally, SymInt is represented in a clever packed way, so that it only
// occupies one word of space; but morally, it is a union between an int64_t
// and an intrusive pointer to SymNodeImpl.
//
// `SymInt` is introduced to enable tracing arithmetic
// operations on symbolic integers (e.g. sizes). Tracing symbolic sizes will
// allow LTC and AOTAutograd representing dynamic shapes in expression graphs
// faithfully without baking in concrete dimension values.
//
// To trace the operations, SymInt will overload arithmetic operators (e.g. +,
// -, *) and will provide overloads taking SymInt for commonly used math
// functions.
//
// SymInt will be extenteded to represent a union structure Union[int64_t,
// SymIntNodeImpl*] which will be implemented as a single packed int64_t field
// named data_.
// Invariant: the referenced SymNodeImpl is guaranteed to be a SymNode where
// is_int() returns true
class C10_API SymInt {
public:
@ -44,6 +39,7 @@ class C10_API SymInt {
TORCH_CHECK(!is_symbolic());
};
SymInt() : data_(0) {}
SymInt(SymNode n);
// unchecked c-tor accepting raw `data_`
// One appropriate use for this is when you are constructing a symint
@ -55,7 +51,7 @@ class C10_API SymInt {
// temporary and then use the move constructor/assignment
SymInt(const SymInt& s) : data_(0) {
if (s.is_symbolic()) {
*this = SymInt::toSymInt(s.toSymIntNodeImpl());
*this = SymInt(s.toSymNodeImpl());
} else {
data_ = s.data_;
}
@ -67,7 +63,7 @@ class C10_API SymInt {
SymInt& operator=(const SymInt& s) {
if (this != &s) {
if (s.is_symbolic()) {
*this = SymInt::toSymInt(s.toSymIntNodeImpl());
*this = SymInt(s.toSymNodeImpl());
} else {
data_ = s.data_;
}
@ -76,7 +72,7 @@ class C10_API SymInt {
}
SymInt& operator=(SymInt&& s) {
if (this != &s) {
release_(); // release the current SymIntNode if any
release_(); // release the current SymNode if any
data_ = s.data_;
if (s.is_symbolic())
s.data_ = 0;
@ -86,31 +82,31 @@ class C10_API SymInt {
SymInt clone() const {
if (is_symbolic()) {
return toSymIntNodeImplUnowned()->clone()->toSymInt();
return SymInt(toSymNodeImplUnowned()->clone());
}
return *this;
}
SymIntNodeImpl* toSymIntNodeImplUnowned() const {
SymNodeImpl* toSymNodeImplUnowned() const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(is_symbolic());
uint64_t unextended_bits = static_cast<uint64_t>(data_) & ~MASK;
uint64_t sign_bit_mask = 1ULL << (62 - 1);
// https://stackoverflow.com/questions/42534749/signed-extension-from-24-bit-to-32-bit-in-c
uint64_t extended_bits = (unextended_bits ^ sign_bit_mask) - sign_bit_mask;
return static_cast<SymIntNodeImpl*>(
return static_cast<SymNodeImpl*>(
reinterpret_cast<void*>(static_cast<uintptr_t>(extended_bits)));
}
void release_() {
if (is_symbolic()) {
SymIntNode::reclaim(toSymIntNodeImplUnowned()); // steal
SymNode::reclaim(toSymNodeImplUnowned()); // steal
}
}
SymIntNodeImpl* release() && {
SymNodeImpl* release() && {
#ifndef C10_MOBILE
TORCH_INTERNAL_ASSERT(is_symbolic());
auto* r = toSymIntNodeImplUnowned();
auto* r = toSymNodeImplUnowned();
data_ = 0; // transfer ownership
return r;
#else
@ -118,8 +114,7 @@ class C10_API SymInt {
#endif
}
SymIntNode toSymIntNodeImpl() const;
static c10::SymInt toSymInt(SymIntNode sin);
SymNode toSymNodeImpl() const;
~SymInt() {
release_();

View File

@ -1,11 +0,0 @@
#include <c10/core/SymInt.h>
#include <c10/core/SymIntNodeImpl.h>
namespace c10 {
c10::SymInt SymIntNodeImpl::toSymInt() {
auto sit_sp = SymIntNode::reclaim_copy(this);
return SymInt::toSymInt(sit_sp);
}
} // namespace c10

3
c10/core/SymNodeImpl.cpp Normal file
View File

@ -0,0 +1,3 @@
#include <c10/core/SymNodeImpl.h>
namespace c10 {} // namespace c10

View File

@ -1,6 +1,5 @@
#pragma once
#include <c10/core/SymFloatNodeImpl.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/intrusive_ptr.h>
@ -10,13 +9,12 @@
namespace c10 {
class SymInt;
class SymIntNodeImpl;
class SymNodeImpl;
using SymNode = c10::intrusive_ptr<SymNodeImpl>;
class C10_API SymIntNodeImpl : public c10::intrusive_ptr_target {
class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
public:
c10::SymInt toSymInt();
virtual ~SymIntNodeImpl(){};
virtual ~SymNodeImpl(){};
template <typename T>
c10::intrusive_ptr<T> dyn_cast() const {
@ -24,66 +22,87 @@ class C10_API SymIntNodeImpl : public c10::intrusive_ptr_target {
}
// these could be pure virtual when we implement LTC versions
virtual SymIntNode add(const SymIntNode& other) {
virtual bool is_int() {
TORCH_CHECK(false, "NYI");
};
virtual SymIntNode sub(const SymIntNode& other) {
virtual bool is_float() {
TORCH_CHECK(false, "NYI");
};
virtual SymIntNode mul(const SymIntNode& other) {
virtual SymNode add(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymFloatNode truediv(const SymIntNode& other) {
virtual SymNode sub(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymIntNode floordiv(const SymIntNode& other) {
virtual SymNode mul(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymIntNode mod(const SymIntNode& other) {
virtual SymNode truediv(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymIntNode eq(const SymIntNode& other) {
virtual SymNode pow(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymIntNode ne(const SymIntNode& other) {
virtual SymNode floordiv(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymIntNode gt(const SymIntNode& other) {
virtual SymNode mod(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymIntNode lt(const SymIntNode& other) {
virtual SymNode eq(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymIntNode le(const SymIntNode& other) {
virtual SymNode ne(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymIntNode ge(const SymIntNode& other) {
virtual SymNode gt(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymIntNode ceil() {
virtual SymNode lt(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymIntNode neg() {
virtual SymNode le(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymIntNode min(const SymIntNode& other) {
virtual SymNode ge(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymIntNode max(const SymIntNode& other) {
virtual SymNode ceil() {
TORCH_CHECK(false, "NYI");
};
virtual SymIntNode clone() {
virtual SymNode floor() {
TORCH_CHECK(false, "NYI");
};
virtual SymFloatNode sym_float() {
virtual SymNode neg() {
TORCH_CHECK(false, "NYI");
};
virtual SymNode min(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode max(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode clone() {
TORCH_CHECK(false, "NYI");
};
virtual SymNode sym_int() {
TORCH_CHECK(false, "NYI");
}
virtual SymIntNode wrap(int64_t num) {
virtual SymNode sym_float() {
TORCH_CHECK(false, "NYI");
}
virtual SymNode wrap_int(int64_t num) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode wrap_float(double num) {
TORCH_CHECK(false, "NYI");
};
virtual int64_t guard_int(const char* file, int64_t line) {
TORCH_CHECK(false, "NYI");
};
virtual double guard_float(const char* file, int64_t line) {
TORCH_CHECK(false, "NYI");
};
virtual int64_t int_() {
TORCH_CHECK(false, "NYI");
};

View File

@ -1,7 +1,7 @@
#include <gtest/gtest.h>
#include <c10/core/SymInt.h>
#include <c10/core/SymIntNodeImpl.h>
#include <c10/core/SymNodeImpl.h>
using namespace c10;
#ifndef C10_MOBILE
@ -20,12 +20,6 @@ TEST(SymIntTest, ConcreteInts) {
check(-4611686018427387904LL);
}
TEST(SymIntTest, AddNode) {
auto n = c10::make_intrusive<SymIntNodeImpl>();
auto i = n->toSymInt();
EXPECT_TRUE(i.is_symbolic());
}
TEST(SymIntTest, CheckRange) {
EXPECT_FALSE(SymInt::check_range(INT64_MIN));
}

View File

@ -335,8 +335,8 @@ coverage_ignore_classes = [
"Quantize",
# torch.utils.backcompat
"Warning",
"SymIntNode",
"SymFloatNode",
"SymInt",
"SymFloat",
]
# The suffix(es) of source filenames.

View File

@ -605,7 +605,7 @@ class PytreeThunk:
return x
return pytree.tree_unflatten(x, self.spec)
KNOWN_TYPES = [torch.Tensor, int, str, float, bool, torch.SymIntNode, torch.SymFloatNode]
KNOWN_TYPES = [torch.Tensor, int, str, float, bool, torch.SymInt, torch.SymFloat]
def aot_function(

View File

@ -209,7 +209,7 @@ def _tensor_nbytes(numel, dtype):
def _size_of(node: fx.Node) -> int:
def to_size_hint(s):
if isinstance(s, torch.SymIntNode):
if isinstance(s, torch.SymInt):
py_s = s.get_pyobj()
return py_s.shape_env.size_hint(py_s.expr)
assert isinstance(s, int)

View File

@ -18,6 +18,8 @@ cond = PyOperator('cond')
def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
def _unwrap_proxy(e):
if not isinstance(e, (torch.Tensor, torch.SymInt, torch.SymFloat)):
return e
return get_proxy_slot(e, proxy_mode.tracer, e, lambda e: e.proxy)
assert isinstance(operands, list), "Cond operands must be a list of tensors"

View File

@ -1447,35 +1447,29 @@ TEST(TestSymInt, AddSymbolicInt) {
}
#ifndef C10_MOBILE
TEST(TestSymInt, TestIntrusive) {
auto a = c10::make_intrusive<c10::SymIntNodeImpl>();
auto b = c10::make_intrusive<c10::SymIntNodeImpl>();
ASSERT_EQ(a.use_count(), 1);
ASSERT_EQ(b.use_count(), 1);
auto as = a->toSymInt();
auto bs = b->toSymInt();
ASSERT_EQ(a.use_count(), 2);
ASSERT_EQ(b.use_count(), 2);
as = bs;
ASSERT_EQ(a.use_count(), 1);
ASSERT_EQ(b.use_count(), 3);
}
class TestSymIntNodeImpl : public c10::SymIntNodeImpl {
class TestSymNodeImpl : public c10::SymNodeImpl {
public:
TestSymIntNodeImpl(int64_t i) : i_(i) {}
explicit TestSymNodeImpl(int64_t i) : i_(i) {}
bool is_int() override {
return true;
};
bool is_float() override {
return false;
};
bool bool_() override {
return static_cast<bool>(i_);
};
#define OPDEF3(NAME, OP, RET) \
RET NAME(const c10::SymIntNode& other) override { \
return make_intrusive<TestSymIntNodeImpl>( \
this->i_ OP dynamic_cast<TestSymIntNodeImpl*>(other.get())->i_); \
#define OPDEF3(NAME, OP, RET) \
RET NAME(const c10::SymNode& other) override { \
return make_intrusive<TestSymNodeImpl>( \
this->i_ OP dynamic_cast<TestSymNodeImpl*>(other.get())->i_); \
}
#define OPDEF2(NAME, OP) OPDEF3(NAME, OP, c10::SymIntNode)
#define OPDEF2(NAME, OP) OPDEF3(NAME, OP, c10::SymNode)
OPDEF2(add, +)
OPDEF2(sub, -)
OPDEF2(mul, *)
@ -1494,17 +1488,19 @@ class TestSymIntNodeImpl : public c10::SymIntNodeImpl {
int64_t i_;
};
TEST(TestSymInt, TestSymIntToSymIntNodeDispatch) {
TEST(TestSymInt, TestSymIntToSymNodeDispatch) {
auto get = [](c10::SymInt si) {
auto node = si.toSymIntNodeImpl();
return dynamic_cast<TestSymIntNodeImpl*>(node.get())->i_;
auto node = si.toSymNodeImpl();
return dynamic_cast<TestSymNodeImpl*>(node.get())->i_;
};
std::vector<int64_t> inputs{0, 1, -1, 4, -4, 777, -777};
for (auto i : inputs) {
for (auto j : inputs) {
auto a = c10::make_intrusive<TestSymIntNodeImpl>(i)->toSymInt();
auto b = c10::make_intrusive<TestSymIntNodeImpl>(j)->toSymInt();
auto a = c10::SymInt(
static_cast<SymNode>(c10::make_intrusive<TestSymNodeImpl>(i)));
auto b = c10::SymInt(
static_cast<SymNode>(c10::make_intrusive<TestSymNodeImpl>(j)));
ASSERT_EQ(get(a + b), i + j);
ASSERT_EQ(get(a - b), i - j);
ASSERT_EQ(get(a * b), i * j);

View File

@ -12,8 +12,9 @@ import itertools
import io
from torch.utils._pytree import tree_map
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv, PySymInt, sym_float
from torch.fx.experimental.symbolic_shapes import ShapeEnv, sym_float, guard_int, SymNode
from torch.utils._python_dispatch import TorchDispatchMode
from torch import SymInt
aten = torch.ops.aten
@ -116,9 +117,6 @@ def create_symbolic_tensor(name, arg, shape_env, storage_offset=0):
sym_shapes, sym_strides = shape_env.create_symbolic_sizes_strides(arg)
return FakeSymbolicTensor(sym_shapes, sym_strides, arg.dtype, arg.layout, arg.requires_grad, arg.device, storage_offset)
CPP_SYMINT_CLASS = type(torch.SymIntNode.new_symint(1))
def create_symint(shape_env, i):
return shape_env.create_symintnode(shape_env.create_symbol(i))
@ -156,8 +154,8 @@ class TestPySymInt(TestCase):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
self.assertTrue(not isinstance(x.shape[0], PySymInt))
self.assertTrue(isinstance(x.shape[0], CPP_SYMINT_CLASS))
self.assertTrue(not isinstance(x.shape[0], SymNode))
self.assertTrue(isinstance(x.shape[0], SymInt))
self.assertTrue(x.shape[0] == 5)
self.assertTrue(x.shape[1] == 4)
@ -165,17 +163,17 @@ class TestPySymInt(TestCase):
self.assertTrue(x.size()[0], 5)
self.assertTrue(x.size()[1], 4)
self.assertTrue(isinstance(x.size()[1], CPP_SYMINT_CLASS))
self.assertTrue(isinstance(x.size()[1], SymInt))
self.assertTrue(x.size()[2] == 3)
self.assertTrue(x.size(0) == 5)
self.assertTrue(x.size(1) == 4)
self.assertTrue(x.size(2) == 3)
self.assertTrue(isinstance(x.size(2), CPP_SYMINT_CLASS))
self.assertTrue(isinstance(x.size(2), SymInt))
offset = create_symint(shape_env, 2)
y = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env, offset)
self.assertTrue(isinstance(y.storage_offset(), CPP_SYMINT_CLASS))
self.assertTrue(isinstance(y.storage_offset(), SymInt))
self.assertTrue(y.storage_offset() == 2)
offset = 2
@ -267,7 +265,7 @@ class TestPySymInt(TestCase):
def test_stride(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 5), shape_env)
self.assertIsInstance(x.stride()[0], CPP_SYMINT_CLASS)
self.assertIsInstance(x.stride()[0], SymInt)
@skipIfNoSympy
def test_size_expressions(self):
@ -290,7 +288,7 @@ class TestPySymInt(TestCase):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
r = sym_float(x.shape[0])
self.assertTrue(isinstance(r, torch.SymFloatNode))
self.assertIsInstance(r, torch.SymFloat, msg=type(r))
@skipIfNoSympy
def test_aten_ops(self):
@ -320,13 +318,13 @@ class TestPySymInt(TestCase):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2)
r = torch.empty(a0, device='meta')
self.assertIsInstance(r.shape[0], CPP_SYMINT_CLASS)
self.assertIsInstance(r.shape[0], SymInt)
@skipIfNoSympy
def test_guard_int(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2)
self.assertEqual(a0.guard_int(), 2)
self.assertEqual(guard_int(a0), 2)
self.assertEqual(str(shape_env.guards[0][0]), "Eq(s0, 2)")
@skipIfNoSympy
@ -347,7 +345,9 @@ class TestPySymInt(TestCase):
assert func == torch.ops.aten.add.Tensor
nonlocal sym_int_encountered
sym_int_encountered = kwargs["alpha"] is a0
# WARNING: do not do identity tests on the outer
# SymInt/SymFloat, they are NOT STABLE
sym_int_encountered = kwargs["alpha"].node is a0.node
kwargs["alpha"] = 0
return func(*args)

View File

@ -1,391 +0,0 @@
# -*- coding: utf-8 -*-
# Owner(s): ["oncall: jit"]
from torch._C import _disabled_torch_function_impl
import torch.fx
import torch.nn.functional as F
from torch.testing._internal.common_utils import run_tests, TestCase, skipIfTorchDynamo
import unittest
import torch
import operator
import itertools
import io
from torch.utils._pytree import tree_map
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv, PySymInt, sym_float
from torch.utils._python_dispatch import TorchDispatchMode
aten = torch.ops.aten
try:
import sympy
HAS_SYMPY = True
except ImportError:
HAS_SYMPY = False
skipIfNoSympy = unittest.skipIf(not HAS_SYMPY, "no sympy")
meta_funcs = {}
def register_meta(op):
def decorator(f):
def add_func(op):
meta_funcs[op] = f
tree_map(add_func, op)
return f
return decorator
@register_meta([aten.add.Tensor, aten.sub.Tensor])
def binary_meta(a, b):
return a.new_empty(a.shape)
@register_meta(aten.cat.default)
def cat_meta(tensors, dim=0):
concat_length = 0
shape = tensors[0].shape
for tensor in tensors:
for idx, (common_length, length) in enumerate(zip(shape, tensor.shape)):
if idx == dim:
concat_length = concat_length + length
else:
assert length == common_length
new_shape = list(shape)
new_shape[dim] = concat_length
return tensors[0].new_empty(new_shape)
@register_meta([aten.narrow_copy.default])
def narrow_copy_symint_meta(a, dim, start, length, **kwargs):
shape = []
for i, x in enumerate(a.shape):
if i == dim:
shape.append(length)
else:
shape.append(x)
return a.new_empty(tuple(shape))
@register_meta([aten.expand.default])
def expand_symint_meta(a, size, implicit=False):
return a.new_empty(size)
def create_contiguous(shape):
strides = [1]
for dim in reversed(shape[:-1]):
strides.append(dim * strides[-1])
return list(reversed(strides))
class FakeSymbolicTensor(torch.Tensor):
@staticmethod
def __new__(cls, sym_shape, sym_strides, dtype, layout, requires_grad, device, storage_offset=0):
# TODO: this is wrong in general
sym_stride = create_contiguous(sym_shape)
r = torch.Tensor._make_wrapper_subclass(
cls, sym_shape,
sym_stride, storage_offset,
dtype=dtype, layout=layout, requires_grad=requires_grad,
device=device,
)
return r
__torch_function__ = _disabled_torch_function_impl
def new_empty(self, shape):
return FakeSymbolicTensor(shape, None, self.dtype, self.layout, self.requires_grad, self.device)
@classmethod
def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
if func_overload in meta_funcs:
return meta_funcs[func_overload](*args, **kwargs)
if func_overload == torch.ops.aten.new_empty.default:
self = args[0]
shape = args[1]
return FakeSymbolicTensor(shape, self.stride(), self.dtype, self.layout, self.requires_grad, self.device)
raise RuntimeError(f"operator {func_overload} not supported")
def create_symbolic_tensor(name, arg, shape_env, storage_offset=0):
sym_shapes, sym_strides = shape_env.create_symbolic_sizes_strides(arg)
return FakeSymbolicTensor(sym_shapes, sym_strides, arg.dtype, arg.layout, arg.requires_grad, arg.device, storage_offset)
CPP_SYMINT_CLASS = type(torch.SymIntNode.new_symint(1))
def create_symint(shape_env, i):
return shape_env.create_symintnode(shape_env.create_symbol(i))
@skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)")
class TestPySymInt(TestCase):
@skipIfNoSympy
def test_arith_ops(self):
shape_env = ShapeEnv()
symints = []
for i in range(2, 5):
symints.append((i, create_symint(shape_env, i)))
ops = [operator.add, operator.sub, operator.floordiv, operator.mul, operator.mod]
for op in ops:
for args in itertools.permutations(symints, 2):
if not isinstance(args[0][1], int) and ((op != operator.mod or op != operator.floordiv) and args[1][0] != 0):
self.assertTrue(op(args[0][1], args[1][1]) == op(args[0][0], args[1][0]))
@skipIfNoSympy
def test_reverse_arith_ops(self):
shape_env = ShapeEnv()
a = create_symint(shape_env, 2)
self.assertTrue(5 // a == 5 // 2)
a = create_symint(shape_env, 2)
self.assertTrue(5 * a == 5 * 2)
@skipIfNoSympy
def test_roundtrip(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
self.assertTrue(not isinstance(x.shape[0], PySymInt))
self.assertTrue(isinstance(x.shape[0], CPP_SYMINT_CLASS))
self.assertTrue(x.shape[0] == 5)
self.assertTrue(x.shape[1] == 4)
self.assertTrue(x.shape[2], 3)
self.assertTrue(x.size()[0], 5)
self.assertTrue(x.size()[1], 4)
self.assertTrue(isinstance(x.size()[1], CPP_SYMINT_CLASS))
self.assertTrue(x.size()[2] == 3)
self.assertTrue(x.size(0) == 5)
self.assertTrue(x.size(1) == 4)
self.assertTrue(x.size(2) == 3)
self.assertTrue(isinstance(x.size(2), CPP_SYMINT_CLASS))
offset = create_symint(shape_env, 2)
y = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env, offset)
self.assertTrue(isinstance(y.storage_offset(), CPP_SYMINT_CLASS))
self.assertTrue(y.storage_offset() == 2)
offset = 2
z = create_symbolic_tensor("z", torch.randn(5, 4, 3), shape_env, offset)
self.assertTrue(isinstance(z.storage_offset(), int))
self.assertTrue(z.storage_offset() == 2)
@skipIfNoSympy
def test_binary(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
y = create_symbolic_tensor("y", torch.randn(5, 4, 3), shape_env)
z = x + y
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)
# broadcasting
y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env)
z = x + y
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)
@skipIfNoSympy
def test_symint_args(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
y = create_symbolic_tensor("y", torch.randn(5, 4, 1), shape_env)
LAST_DIM = 2
z = x.narrow_copy(LAST_DIM, 0, y.shape[LAST_DIM])
self.assertTrue(z.shape[2] == y.shape[2])
# arithmetic expr with two symints
z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - y.shape[LAST_DIM])
self.assertTrue(z.shape[2] == 2)
# arithmetic expr with a symint and python int
z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - 1)
self.assertTrue(z.shape[2] == 2)
@skipIfNoSympy
def test_symint_vargs(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env)
# varargs
z = y.expand(x.shape[0], y.shape[1], x.shape[2])
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)
# shape list
z = y.expand((x.shape[0], y.shape[1], x.shape[2]))
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)
# mixed python symints and ints
z = y.expand(x.shape[0], y.shape[1], 3)
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)
# mixed python symints and ints in a list
z = y.expand((x.shape[0], y.shape[1], 3))
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)
# mixed python symints and ints
z = y.expand(5, y.shape[1], x.shape[2])
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)
# mixed python ints and symints in a list
z = y.expand((5, y.shape[1], x.shape[2]))
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)
z = y.expand((y.shape[1],))
z = y.expand(y.shape[1])
@skipIfNoSympy
def test_stride(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 5), shape_env)
self.assertIsInstance(x.stride()[0], CPP_SYMINT_CLASS)
@skipIfNoSympy
def test_size_expressions(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
expand_x = x.expand(x.shape[0], x.shape[0])
if expand_x.shape[0] > 3:
result = expand_x + expand_x
else:
result = expand_x + expand_x
gt_op = shape_env.guards[0][0]
self.assertTrue(isinstance(gt_op, sympy.core.relational.StrictGreaterThan))
self.assertTrue(str(x.shape[0]), str(gt_op.args[0]))
self.assertTrue(str(expand_x.shape[1]), str(x.shape[0]))
self.assertTrue(str(expand_x.shape[1]), str(result.shape[0]))
@skipIfNoSympy
def test_int_to_float(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
r = sym_float(x.shape[0])
self.assertTrue(isinstance(r, torch.SymFloatNode))
@skipIfNoSympy
def test_aten_ops(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
torch.ops.aten.narrow_copy.default(x, 0, 0, x.shape[0])
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
torch.ops.aten.expand.default(x, [x.shape[0], x.shape[1], x.shape[2]])
def test_fx_trace_intlist(self):
class CustomModule(torch.nn.Module):
def forward(self, x):
bs, c, h, w = x.shape
return F.pad(x, (0, w % 2, 0, h % 2, 0, 0))
m = CustomModule()
x = torch.rand(1, 3, 4, 4)
# should not TypeError: pad(): argument 'pad' (position 2) must be
# tuple of ints, not tuple
torch.fx.symbolic_trace(m)
@skipIfNoSympy
def test_meta_symint(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2)
r = torch.empty(a0, device='meta')
self.assertIsInstance(r.shape[0], CPP_SYMINT_CLASS)
@skipIfNoSympy
def test_guard_int(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2)
self.assertEqual(a0.guard_int(), 2)
self.assertEqual(str(shape_env.guards[0][0]), "s0")
self.assertEqual(shape_env.guards[0][1], 2)
@skipIfNoSympy
def test_int_conversion(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2)
self.assertRaisesRegex(RuntimeError, "Trying to extract", lambda: int(a0))
@skipIfNoSympy
def test_symint_as_scalar(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2)
sym_int_encountered = False
class TestSymInt(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
assert func == torch.ops.aten.add.Tensor
nonlocal sym_int_encountered
sym_int_encountered = kwargs["alpha"] is a0
kwargs["alpha"] = 0
return func(*args)
x = torch.rand([4, 4])
with TestSymInt():
y = torch.add(x, x, alpha=a0)
self.assertTrue(sym_int_encountered)
@skipIfNoSympy
@unittest.mock.patch('sys.stdout', new_callable=io.StringIO)
def test_print_readable_with_symints(self, mock_stdout):
def f(a, b):
dim0 = a.shape[0] + b.shape[0]
dim1 = a.shape[1] + b.shape[1]
d = a.new_empty(dim0, dim1)
d = torch.ops.aten.native_dropout(d, 0.5, train=True)
return d
fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5, 3), torch.randn(4, 3))
fx_g.print_readable()
self.assertExpectedInline(mock_stdout.getvalue().strip(), """\
class f(torch.nn.Module):
def forward(self, a_1: f32[t0.size(0),t0.size(1)], b_1: f32[t1.size(0),t0.size(1)]):
# No stacktrace found for following nodes
sym_size: Sym(t0.size(0)) = torch.ops.aten.sym_size(a_1, 0)
sym_size_1: Sym(t1.size(0)) = torch.ops.aten.sym_size(b_1, 0)
add: Sym(t0.size(0) + t1.size(0)) = sym_size + sym_size_1; sym_size = sym_size_1 = None
sym_size_2: Sym(t0.size(1)) = torch.ops.aten.sym_size(a_1, 1)
sym_size_3: Sym(t0.size(1)) = torch.ops.aten.sym_size(b_1, 1); b_1 = None
add_1: Sym(2*t0.size(1)) = sym_size_2 + sym_size_3; sym_size_2 = sym_size_3 = None
new_empty: f32[t0.size(0) + t1.size(0),2*t0.size(1)] = torch.ops.aten.new_empty.default(a_1, [add, add_1], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False); a_1 = add = add_1 = None
native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True); new_empty = None
getitem: f32[t0.size(0) + t1.size(0),2*t0.size(1)] = native_dropout[0]
getitem_1: b8[t0.size(0) + t1.size(0),2*t0.size(1)] = native_dropout[1]; native_dropout = None
return (getitem, getitem_1)""") # noqa: B950
if __name__ == '__main__':
run_tests()

View File

@ -875,8 +875,7 @@ def forward(self, a_1):
self.assertExpectedInline(r, """\
def forward(self, a_1):
sym_size = torch.ops.aten.sym_size(a_1, 0)
sym_float = torch.fx.experimental.symbolic_shapes.sym_float(sym_size); sym_size = None
pow_1 = sym_float ** 0.5; sym_float = None
pow_1 = sym_size ** 0.5; sym_size = None
div = torch.ops.aten.div.Tensor(a_1, pow_1); a_1 = pow_1 = None
return div""")
@ -949,7 +948,7 @@ def forward(self, a_1):
fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5), torch.randn(4))
meta_c = _get_node(fx_g, lambda x: x.target == aten.new_empty.default)
meta_d = _get_node(fx_g, lambda x: x.target == operator.add)
self.assertTrue(meta_c.meta['val'].shape[0].get_pyobj().expr == meta_d.meta['val'].expr)
self.assertTrue(meta_c.meta['val'].shape[0].get_pyobj().expr == meta_d.meta['val'].node.expr)
def test_metadata_fresh(self):
def f(x):

View File

@ -207,8 +207,8 @@ class TestPublicBindings(TestCase):
"StreamObjType",
"StringType",
"SUM",
"SymFloatNode",
"SymIntNode",
"SymFloat",
"SymInt",
"TensorType",
"ThroughputBenchmark",
"TracingState",

View File

@ -291,7 +291,7 @@ PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
for (auto i : c10::irange(prop.size())) {
auto si = prop[i];
if (si.is_symbolic()) {
auto py_symint = py::cast(si.toSymIntNodeImpl()).release().ptr();
auto py_symint = py::cast(si).release().ptr();
PyTuple_SetItem(tup, (Py_ssize_t) i, py_symint);
} else {
PyTuple_SetItem(tup, (Py_ssize_t) i, PyLong_FromUnsignedLong(si.as_int_unchecked()));
@ -313,7 +313,7 @@ return PyLong_FromUnsignedLong((int64_t) prop);
"""
GETTER_BODY_SYMINT = """\
return prop.is_symbolic() ? py::cast(prop.toSymIntNodeImpl()).release().ptr() : PyLong_FromUnsignedLong(prop.as_int_unchecked());
return prop.is_symbolic() ? py::cast(prop).release().ptr() : PyLong_FromUnsignedLong(prop.as_int_unchecked());
"""
GETTER_BODY_DOUBLE = """\

View File

@ -5,7 +5,7 @@
#include <Python.h>
#include <ATen/ATen.h>
#include <c10/core/SymIntNodeImpl.h>
#include <c10/core/SymNodeImpl.h>
#include "torch/csrc/autograd/generated/Functions.h"
#include "torch/csrc/autograd/python_cpp_function.h"
#include <torch/csrc/autograd/python_variable.h>

View File

@ -240,12 +240,7 @@ static PyObject * THPVariable_numel(PyObject* self, PyObject* args)
if (jit::tracer::isTracing()) {
return wrap(jit::tracer::getNumelOf(self_));
} else {
auto si = self_.sym_numel();
if (si.is_symbolic()) {
return py::cast(si.toSymIntNodeImpl()).release().ptr();
} else {
return THPUtils_packInt64(si.as_int_unchecked());
}
return py::cast(self_.sym_numel()).release().ptr();
}
END_HANDLE_TH_ERRORS
}

View File

@ -722,7 +722,7 @@ def gen_pyi(
binop += "_"
out_suffix = ""
unsorted_tensor_method_hints[binop].append(
"def {}(self, other: Union[Tensor, Number, torch.SymIntNode, torch.SymFloatNode]{})"
"def {}(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat]{})"
" -> Tensor: ...".format(binop, out_suffix)
)
for binop in ["add", "sub"]:
@ -732,7 +732,7 @@ def gen_pyi(
binop += "_"
out_suffix = ""
unsorted_tensor_method_hints[binop].append(
"def {}(self, other: Union[Tensor, Number, torch.SymIntNode, torch.SymFloatNode], "
"def {}(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat], "
"*, alpha: Optional[Number]=1{})"
" -> Tensor: ...".format(binop, out_suffix)
)

View File

@ -169,20 +169,6 @@ class Future(object):
def _jit_set_num_profiled_runs(num: _size) -> _size: ...
class SymIntNode(object):
def get_pyobj(self) -> Any: ...
@staticmethod
def new_symint(obj) -> SymIntNode: ...
class SymFloatNode(object):
def get_pyobj(self) -> Any: ...
@staticmethod
def new_symfloat(obj) -> SymFloatNode: ...
def __ceil__(self) -> SymIntNode: ...
# Defined in torch/csrc/jit/passes/xnnpack_rewrite.h
class MobileOptimizerType:
...

View File

@ -47,7 +47,7 @@ __all__ = [
'is_deterministic_algorithms_warn_only_enabled',
'set_deterministic_debug_mode', 'get_deterministic_debug_mode',
'set_float32_matmul_precision', 'get_float32_matmul_precision',
'set_warn_always', 'is_warn_always_enabled',
'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat',
]
################################################################################
@ -196,6 +196,67 @@ else:
if TYPE_CHECKING:
import torch._C as _C
class SymInt:
"""
Like an int (including magic methods), but redirects all operations on the
wrapped node. This is used in particular to symbolically record operations
in the symbolic shape workflow.
"""
def __init__(self, node):
from torch.fx.experimental.symbolic_shapes import SymNode
assert isinstance(node, SymNode)
# This field MUST be named node; C++ binding code assumes that this
# class has a field named node that stores SymNode
self.node = node
# Magic methods installed later
def __bool__(self):
return self.node.bool_()
def __int__(self):
return self.node.int_()
def __sym_float__(self):
return SymFloat(self.node.sym_float())
def __repr__(self):
return self.node.str()
# For BC; direct access of node is OK too
def get_pyobj(self):
return self.node
class SymFloat:
"""
Like an float (including magic methods), but redirects all operations on the
wrapped node. This is used in particular to symbolically record operations
in the symbolic shape workflow.
"""
def __init__(self, node):
from torch.fx.experimental.symbolic_shapes import SymNode
assert isinstance(node, SymNode)
# This field MUST be named node; C++ binding code assumes that this
# class has a field named node that stores SymNode
self.node = node
# Magic methods installed later
def __bool__(self):
return self.node.bool_()
def __sym_int__(self):
return SymInt(self.node.sym_int())
def __repr__(self):
return self.node.str()
# For BC; direct access of node is OK too
def get_pyobj(self):
return self.node
# Check to see if we can load C extensions, and if not provide some guidance
# on what the problem might be.
try:
@ -941,7 +1002,6 @@ from ._linalg_utils import ( # type: ignore[misc]
lstsq,
)
def _register_device_module(device_type, module):
r"""Register an external runtime module of the specific :attr:`device_type`
supported by torch.
@ -971,3 +1031,6 @@ if 'TORCH_CUDA_SANITIZER' in os.environ:
import torch.cuda._sanitizer as csan
csan.enable_cuda_sanitizer()
# Populate magic methods on SymInt and SymFloat
import torch.fx.experimental.symbolic_shapes

View File

@ -337,7 +337,7 @@ class TensorVariable(VariableTracker):
from . import UserDefinedObjectVariable
return UserDefinedObjectVariable(example_value)
elif isinstance(example_value, torch.SymIntNode):
elif isinstance(example_value, torch.SymInt):
proxy.node.meta["example_value"] = example_value
return cls(proxy, **options)
else:

View File

@ -40,11 +40,9 @@ class GraphLowering(torch.fx.Interpreter):
else:
size, stride = self._shape_env.create_symbolic_sizes_strides(ex)
size = [
i.get_pyobj().expr if isinstance(i, torch.SymIntNode) else i for i in size
]
size = [i.get_pyobj().expr if isinstance(i, torch.SymInt) else i for i in size]
stride = [
i.get_pyobj().expr if isinstance(i, torch.SymIntNode) else i for i in stride
i.get_pyobj().expr if isinstance(i, torch.SymInt) else i for i in stride
]
return size, stride

View File

@ -392,8 +392,8 @@ def _elementwise_meta(
# Number case
# NOTE: this case is not currently exercised
# TODO: fix number type promotion (bool, complex->float)
assert not isinstance(number, torch.SymIntNode), "NYI"
assert not isinstance(number, torch.SymFloatNode), "NYI"
assert not isinstance(number, torch.SymInt), "NYI"
assert not isinstance(number, torch.SymFloat), "NYI"
return TensorMeta(number)
@ -932,7 +932,7 @@ bitwise_xor = _make_elementwise_binary_prim(
# div prim performs truncation division on integer inputs
# and true division for floating and complex inputs
def _div_aten(a, b):
is_integral = isinstance(a, (bool, int, torch.SymIntNode)) or (
is_integral = isinstance(a, (bool, int, torch.SymInt)) or (
isinstance(a, torch.Tensor) and utils.is_integer_dtype(a.dtype)
)

View File

@ -42,18 +42,18 @@ ShapeType = Union[torch.Size, List[int], Tuple[int, ...]]
StrideType = Union[List[int], Tuple[int, ...]]
DimsType = Union[int, List[int], Tuple[int, ...]]
DimsSequenceType = Union[List[int], Tuple[int, ...]]
# TODO: Type[torch.SymIntNode], Type[torch.SymFloatNode]
# TODO: Type[torch.SymInt], Type[torch.SymFloat]
NumberTypeType = Union[Type[bool], Type[int], Type[float], Type[complex]]
# TODO: This needs a lot more type annotations
# NumberType = Union[bool, int, float, complex, torch.SymIntNode, torch.SymFloatNode]
# NumberType = Union[bool, int, float, complex, torch.SymInt, torch.SymFloat]
NumberType = Union[bool, int, float, complex]
Number = (bool, int, float, complex, torch.SymIntNode, torch.SymFloatNode)
Number = (bool, int, float, complex, torch.SymInt, torch.SymFloat)
# I don't call it Integral because numbers.Integral includes bool, but IntLike
# does not
Dim = int
IntLike = (int, torch.SymIntNode)
FloatLike = (float, torch.SymFloatNode)
IntLike = (int, torch.SymInt)
FloatLike = (float, torch.SymFloat)
IntWithoutSymInt = int
FloatWithoutSymFloat = float
DeviceLikeType = Union[str, torch.device]
@ -1113,10 +1113,10 @@ class RETURN_TYPE(Enum):
# TODO: when NumberType contains the sym types, can simplify this
def number_type(x: Union[NumberType, torch.SymIntNode, torch.SymFloatNode]) -> Type:
if isinstance(x, torch.SymIntNode):
def number_type(x: Union[NumberType, torch.SymInt, torch.SymFloat]) -> Type:
if isinstance(x, torch.SymInt):
return int
elif isinstance(x, torch.SymFloatNode):
elif isinstance(x, torch.SymFloat):
return float
else:
return type(x)

View File

@ -656,7 +656,7 @@ class FakeTensorMode(TorchDispatchMode):
return args[0].fake_device
flat_arg_fake_tensors = tree_flatten_only(FakeTensor, (args, kwargs))
flat_symints = tree_flatten_only(torch.SymIntNode, (args, kwargs))
flat_symints = tree_flatten_only(torch.SymInt, (args, kwargs))
has_symbolic_sizes = (
any([i._has_symbolic_sizes_strides for i in flat_arg_fake_tensors])
or len(flat_symints) > 0

View File

@ -59,7 +59,7 @@ PyObject* THPSize_NewFromSymSizes(const at::Tensor& self_) {
TORCH_CHECK(
!torch::jit::tracer::isTracing(),
"JIT Tracing of SymInts isn't supported");
auto py_symint = py::cast(si.toSymIntNodeImpl()).release().ptr();
auto py_symint = py::cast(si).release().ptr();
if (!py_symint)
throw python_error();
PyTuple_SET_ITEM(ret.get(), i, py_symint);
@ -98,7 +98,7 @@ static PyObject* THPSize_pynew(
if (THPUtils_checkLong(item)) {
continue;
}
if (torch::is_symint_node(item)) {
if (torch::is_symint(item)) {
continue;
}
if (torch::jit::tracer::isTracing() && isTracedZeroDimVar(item)) {
@ -135,7 +135,7 @@ static PyObject* THPSize_repr(THPSize* self) {
auto item = PyTuple_GET_ITEM(self, i);
auto ih = py::handle(item);
repr += torch::is_symint_node(ih)
repr += torch::is_symint(ih)
? std::string(py::str(ih))
: std::to_string(THPUtils_unpackLong(PyTuple_GET_ITEM(self, i)));
}

View File

@ -2646,9 +2646,8 @@ c10::SymInt ConcretePyInterpreterVTable::sym_numel(
"Cannot call numel on a tensor with symbolic shapes/strides");
return self->sym_numel_default();
}
return torch::is_symint_node(out)
? out.cast<c10::SymIntNodeImpl*>()->toSymInt()
: c10::SymInt{py::cast<int64_t>(out)};
return torch::is_symint(out) ? out.cast<c10::SymInt>()
: c10::SymInt{py::cast<int64_t>(out)};
}
c10::SymInt ConcretePyInterpreterVTable::sym_storage_offset(
@ -2669,9 +2668,8 @@ c10::SymInt ConcretePyInterpreterVTable::sym_storage_offset(
if (out.is(py::none())) {
return self->sym_storage_offset_default();
}
return torch::is_symint_node(out)
? out.cast<c10::SymIntNodeImpl*>()->toSymInt()
: c10::SymInt{py::cast<int64_t>(out)};
return torch::is_symint(out) ? out.cast<c10::SymInt>()
: c10::SymInt{py::cast<int64_t>(out)};
}
c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_strides(
@ -2701,9 +2699,8 @@ c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_strides(
py::list symints;
for (auto it = out.begin(); it != out.end(); it++) {
auto elm = *it;
auto si = torch::is_symint_node(elm)
? elm.cast<c10::SymIntNodeImpl*>()->toSymInt()
: c10::SymInt{py::cast<int64_t>(elm)};
auto si = torch::is_symint(elm) ? elm.cast<c10::SymInt>()
: c10::SymInt{py::cast<int64_t>(elm)};
symints.append(si.as_int_unchecked());
}

View File

@ -13,7 +13,7 @@
#if (!defined(FBCODE_CAFFE2) && defined(BUILD_ONEDNN_GRAPH))
#include <torch/csrc/jit/codegen/onednn/interface.h>
#endif
#include <c10/core/SymIntNodeImpl.h>
#include <c10/core/SymNodeImpl.h>
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/ir/irparser.h>
@ -99,7 +99,6 @@
#include <torch/csrc/jit/tensorexpr/tensorexpr_init.h>
#include <torch/csrc/utils/cpp_stacktraces.h>
#include <c10/core/SymFloat.h>
#include <c10/macros/Export.h>
#include <c10/util/irange.h>
#include <c10/util/signal_handler.h>
@ -126,249 +125,11 @@ using c10::Argument;
using c10::FunctionSchema;
using c10::SchemaArgType;
using c10::SchemaArgument;
using c10::SymFloat;
using c10::SymFloatNode;
using c10::SymIntNode;
using c10::SymNode;
using caffe2::serialize::PyTorchStreamReader;
using caffe2::serialize::PyTorchStreamWriter;
using torch::utils::SchemaInfo;
static c10::SymIntNode toSymIntNode(c10::SymIntNode a, py::object b) {
return torch::is_symint_node(b) ? b.cast<c10::SymIntNode>()
: a->wrap(b.cast<int64_t>());
}
static c10::SymFloatNode toSymFloatNode(c10::SymFloatNode a, py::object b) {
if (torch::is_symfloat_node(b)) {
return b.cast<c10::SymFloatNode>();
} else if (torch::is_symint_node(b)) {
return b.cast<c10::SymIntNode>()->sym_float();
} else {
return a->wrap(b.cast<double>());
}
}
class PythonSymIntNodeImpl : public c10::SymIntNodeImpl {
public:
PythonSymIntNodeImpl(py::object pyobj) : c10::SymIntNodeImpl() {
pyobj_ = std::make_shared<c10::SafePyObject>(
pyobj.release().ptr(), getPyInterpreter());
};
virtual SymIntNode clone() override {
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr("clone")();
return c10::make_intrusive<PythonSymIntNodeImpl>(r);
}
virtual SymIntNode wrap(int64_t num) override {
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr("wrap")(num);
return c10::make_intrusive<PythonSymIntNodeImpl>(r);
}
virtual bool bool_() override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("__bool__")().is(py::handle(Py_True));
}
virtual int64_t guard_int(const char* file, int64_t line) override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("guard_int")(file, line).cast<int64_t>();
}
virtual int64_t int_() override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("__int__")().cast<int64_t>();
}
SymFloatNode sym_float() override;
virtual std::string str() override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("__str__")().cast<std::string>();
}
virtual SymIntNode dispatch_common_(
const char* fname,
const SymIntNode& other) {
auto pother = dynamic_cast<PythonSymIntNodeImpl*>(other.get());
TORCH_CHECK(pother);
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr(fname)(pother->getPyObj());
return c10::make_intrusive<PythonSymIntNodeImpl>(r);
}
virtual SymIntNode dispatch_common_(const char* fname) {
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr(fname)();
return c10::make_intrusive<PythonSymIntNodeImpl>(r);
}
virtual SymIntNode add(const SymIntNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
virtual SymIntNode sub(const SymIntNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
virtual SymIntNode mul(const SymIntNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
virtual SymFloatNode truediv(const SymIntNode& other) override;
virtual SymIntNode floordiv(const SymIntNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
virtual SymIntNode mod(const SymIntNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
virtual SymIntNode eq(const SymIntNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
virtual SymIntNode gt(const SymIntNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
virtual SymIntNode lt(const SymIntNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
virtual SymIntNode le(const SymIntNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
virtual SymIntNode ge(const SymIntNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
virtual SymIntNode min(const SymIntNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
virtual SymIntNode max(const SymIntNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
virtual SymIntNode ceil() override {
return dispatch_common_(__FUNCTION__);
}
virtual SymIntNode neg() override {
return dispatch_common_(__FUNCTION__);
}
py::handle getPyObj() {
return py::handle(pyobj_.get()->ptr(getPyInterpreter()));
}
std::shared_ptr<c10::SafePyObject> pyobj_ = nullptr;
};
class PythonSymFloatNodeImpl : public c10::SymFloatNodeImpl {
public:
PythonSymFloatNodeImpl(py::object pyobj) : c10::SymFloatNodeImpl() {
pyobj_ = std::make_shared<c10::SafePyObject>(
pyobj.release().ptr(), getPyInterpreter());
};
virtual SymFloatNode wrap(double num) override {
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr("wrap")(num);
return c10::make_intrusive<PythonSymFloatNodeImpl>(r);
}
virtual std::string str() override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("__str__")().cast<std::string>();
}
SymFloatNode dispatch_common_(const char* fname, const SymFloatNode& other) {
auto pother = dynamic_cast<PythonSymFloatNodeImpl*>(other.get());
TORCH_CHECK(pother);
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr(fname)(pother->getPyObj());
return c10::make_intrusive<PythonSymFloatNodeImpl>(r);
}
SymFloatNode add(const SymFloatNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
SymFloatNode sub(const SymFloatNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
SymFloatNode mul(const SymFloatNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
SymFloatNode truediv(const SymFloatNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
SymFloatNode pow(const SymFloatNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
SymFloatNode eq(const SymFloatNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
SymFloatNode gt(const SymFloatNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
SymFloatNode lt(const SymFloatNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
SymFloatNode le(const SymFloatNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
SymFloatNode ge(const SymFloatNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
SymIntNode ceil() override;
SymIntNode floor() override;
py::handle getPyObj() {
return py::handle(pyobj_.get()->ptr(getPyInterpreter()));
}
std::shared_ptr<c10::SafePyObject> pyobj_ = nullptr;
};
SymFloatNode PythonSymIntNodeImpl::truediv(const SymIntNode& other) {
auto pother = dynamic_cast<PythonSymIntNodeImpl*>(other.get());
TORCH_CHECK(pother);
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr("truediv")(pother->getPyObj());
return c10::make_intrusive<PythonSymFloatNodeImpl>(r);
}
SymFloatNode PythonSymIntNodeImpl::sym_float() {
py::gil_scoped_acquire acquire;
return c10::make_intrusive<PythonSymFloatNodeImpl>(
getPyObj().attr("__sym_float__")());
}
SymIntNode PythonSymFloatNodeImpl::ceil() {
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr("ceil")();
return c10::make_intrusive<PythonSymIntNodeImpl>(r);
}
SymIntNode PythonSymFloatNodeImpl::floor() {
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr("floor")();
return c10::make_intrusive<PythonSymIntNodeImpl>(r);
}
namespace {
using autograd::variable_list;
@ -1381,276 +1142,41 @@ void initJITBindings(PyObject* module) {
}
});
auto symint_class =
py::class_<c10::SymIntNodeImpl, c10::SymIntNode>(m, "SymIntNode")
.def_static(
"new_symint",
[](py::object obj) -> c10::SymIntNode {
return c10::make_intrusive<PythonSymIntNodeImpl>(obj);
})
.def(
"get_pyobj",
[](c10::SymIntNode a) -> py::object {
if (auto* psn = dynamic_cast<PythonSymIntNodeImpl*>(a.get())) {
return py::reinterpret_borrow<py::object>(psn->getPyObj());
}
return py::none();
})
.def(
"__add__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->add(snb);
})
.def(
"__radd__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return snb->add(a);
})
.def(
"__sub__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->sub(snb);
})
.def(
"__rsub__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return snb->sub(a);
})
.def(
"__mul__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->mul(snb);
})
.def(
"__rmul__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return snb->mul(a);
})
.def(
"__truediv__",
[](c10::SymIntNode a, py::object b) -> c10::SymFloatNode {
auto snb = toSymIntNode(a, b);
return a->truediv(snb);
})
.def(
"__rtruediv__",
[](c10::SymIntNode a, py::object b) -> c10::SymFloatNode {
auto snb = toSymIntNode(a, b);
return snb->truediv(a);
})
.def(
"__floordiv__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->floordiv(snb);
})
.def(
"__rfloordiv__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return snb->floordiv(a);
})
.def(
"__mod__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->mod(snb);
})
.def(
"__rmod__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return snb->mod(a);
})
.def(
"__pow__",
[](c10::SymIntNode a, py::object b) -> py::object {
if (PyFloat_Check(b.ptr())) {
auto float_a = a->sym_float();
return py::cast(
float_a->pow(float_a->wrap(py::cast<double>(b))));
}
// TODO: integer pow
return py::reinterpret_borrow<py::object>(Py_NotImplemented);
})
// TODO: rpow
.def(
"__eq__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->eq(snb);
})
.def(
"__gt__",
[](c10::SymIntNode a, py::object b) {
auto snb = toSymIntNode(a, b);
return a->gt(snb);
})
.def(
"__lt__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->lt(snb);
})
.def(
"__le__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->le(snb);
})
.def(
"__ge__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->ge(snb);
})
.def(
"__ceil__",
[](c10::SymIntNode a) -> c10::SymIntNode { return a->ceil(); })
.def(
"__neg__",
[](c10::SymIntNode a) -> c10::SymIntNode { return a->neg(); })
.def(
"__min__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->min(snb);
})
.def(
"__max__",
[](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
auto snb = toSymIntNode(a, b);
return a->max(snb);
})
.def("__bool__", [](c10::SymIntNode a) { return a->bool_(); })
.def("__int__", [](c10::SymIntNode a) { return a->int_(); })
// Intentionally don't set file line, as the Python backtrace matters
// more here
.def(
"guard_int",
[](c10::SymIntNode a) { return a->guard_int(nullptr, 0); })
.def(
"__sym_float__",
[](c10::SymIntNode a) {
// TODO: remove dynamic cast when sym_float is in base class
auto* psn = dynamic_cast<PythonSymIntNodeImpl*>(a.get());
TORCH_INTERNAL_ASSERT(psn);
return psn->sym_float();
})
.def("__str__", [](c10::SymIntNode a) { return a->str(); })
.def("__repr__", [](c10::SymIntNode a) { return a->str(); });
py::class_<c10::SymFloatNodeImpl, c10::SymFloatNode>(m, "SymFloatNode")
.def_static(
"new_symfloat",
[](py::object obj) -> c10::SymFloatNode {
return c10::make_intrusive<PythonSymFloatNodeImpl>(obj);
})
.def(
"__add__",
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
auto snb = toSymFloatNode(a, b);
return a->add(snb);
})
.def(
"__radd__",
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
auto snb = toSymFloatNode(a, b);
return snb->add(a);
})
.def(
"__sub__",
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
auto snb = toSymFloatNode(a, b);
return a->sub(snb);
})
.def(
"__mul__",
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
auto snb = toSymFloatNode(a, b);
return a->mul(snb);
})
.def(
"__rmul__",
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
auto snb = toSymFloatNode(a, b);
return snb->mul(a);
})
.def(
"__truediv__",
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
auto snb = toSymFloatNode(a, b);
return a->truediv(snb);
})
.def(
"__rtruediv__",
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
auto snb = toSymFloatNode(a, b);
return snb->truediv(a);
})
.def(
"__eq__",
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
auto snb = toSymFloatNode(a, b);
return a->eq(snb);
})
.def(
"__gt__",
[](c10::SymFloatNode a, py::object b) {
auto snb = toSymFloatNode(a, b);
return a->gt(snb);
})
.def(
"__lt__",
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
auto snb = toSymFloatNode(a, b);
return a->lt(snb);
})
.def(
"__le__",
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
auto snb = toSymFloatNode(a, b);
return a->le(snb);
})
.def(
"__ge__",
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
auto snb = toSymFloatNode(a, b);
return a->ge(snb);
})
.def(
"__pow__",
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
auto snb = toSymFloatNode(a, b);
return a->pow(snb);
})
.def(
"__rpow__",
[](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode {
auto snb = toSymFloatNode(a, b);
return snb->pow(a);
})
.def(
"__ceil__",
[](c10::SymFloatNode a) -> c10::SymIntNode { return a->ceil(); })
.def(
"__floor__",
[](c10::SymFloatNode a) -> c10::SymIntNode { return a->floor(); })
.def(
"get_pyobj",
[](c10::SymFloatNode a) -> py::object {
if (auto* psn = dynamic_cast<PythonSymFloatNodeImpl*>(a.get())) {
return py::reinterpret_borrow<py::object>(psn->getPyObj());
}
return py::none();
})
.def("__str__", [](c10::SymFloatNode a) { return a->str(); });
// NB: This isn't actually used for regular PyTorch symbolic tracing;
// XLA is what needs this
#define SYMNODE_UNARY(n) .def(#n, [](c10::SymNode a) { return a->n(); })
#define SYMNODE_UNARY2(n2, n) .def(#n2, [](c10::SymNode a) { return a->n(); })
#define SYMNODE_BINARY(n) \
.def(#n, [](c10::SymNode a, c10::SymNode b) { return a->n(b); })
auto symnode_class =
py::class_<c10::SymNodeImpl, c10::SymNode>(m, "_SymNode")
// These DO NOT install magic methods; the SymInt/SymFloat wrapper in
// Python is responsible for this
SYMNODE_UNARY(clone)
// Named these for consistency with inner python class, but maybe
// should change the python side
SYMNODE_UNARY2(__bool__, bool_) SYMNODE_UNARY2(__int__, int_)
SYMNODE_UNARY2(__sym_int__, sym_int) SYMNODE_UNARY2(
__sym_float__, sym_float) SYMNODE_BINARY(add) SYMNODE_BINARY(sub)
SYMNODE_BINARY(mul) SYMNODE_BINARY(truediv) SYMNODE_BINARY(pow)
SYMNODE_BINARY(floordiv) SYMNODE_BINARY(mod) SYMNODE_BINARY(
eq) SYMNODE_BINARY(gt) SYMNODE_BINARY(lt)
SYMNODE_BINARY(le) SYMNODE_BINARY(ge) SYMNODE_BINARY(min)
SYMNODE_BINARY(max) SYMNODE_UNARY(ceil)
SYMNODE_UNARY(floor) SYMNODE_UNARY(neg)
// Intentionally don't set file line, as the
// Python backtrace matters more here
.def(
"guard_int",
[](c10::SymNode a) {
return a->guard_int(nullptr, 0);
})
.def(
"__str__",
[](c10::SymNode a) { return a->str(); })
.def("__repr__", [](c10::SymNode a) {
return a->str();
});
// NOLINTNEXTLINE(bugprone-unused-raii)
py::class_<CompleteArgumentSpec>(m, "CompleteArgumentSpec")

View File

@ -80,10 +80,10 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
scalar = at::Scalar(THPUtils_unpackComplexDouble(obj.ptr()));
} else if (THPUtils_checkDouble(obj.ptr())) {
scalar = at::Scalar(THPUtils_unpackDouble(obj.ptr()));
} else if (torch::is_symint_node(py::handle(obj))) {
} else if (torch::is_symint(py::handle(obj))) {
save_symint = true;
scalar = at::Scalar(7777777);
} else if (torch::is_symfloat_node(py::handle(obj))) {
} else if (torch::is_symfloat(py::handle(obj))) {
save_symint = true;
scalar = at::Scalar(std::numeric_limits<double>::quiet_NaN());
} else {
@ -161,12 +161,12 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
return py::cast<int64_t>(obj);
}
case TypeKind::SymIntType:
if (torch::is_symint_node(obj.ptr())) {
if (torch::is_symint(obj.ptr())) {
return py::cast<c10::SymInt>(obj);
}
return py::cast<int64_t>(obj);
case TypeKind::SymFloatType:
if (torch::is_symfloat_node(obj.ptr())) {
if (torch::is_symfloat(obj.ptr())) {
return py::cast<c10::SymFloat>(obj);
}
return py::cast<double>(obj);
@ -253,7 +253,7 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
bool is_symbolic = false;
for (auto it = obj.begin(); it != obj.end(); it++) {
auto elm = *it;
if (torch::is_symint_node(elm)) {
if (torch::is_symint(elm)) {
is_symbolic = true;
break;
}
@ -269,7 +269,7 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
for (auto it = obj.begin(); it != obj.end(); it++) {
auto elm = *it;
// TODO: what about SymInt conversion to SymFloat?
if (torch::is_symfloat_node(elm)) {
if (torch::is_symfloat(elm)) {
is_symbolic = true;
break;
}
@ -442,9 +442,9 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
} else if (PyComplex_CheckExact(obj.ptr())) {
auto c_obj = py::cast<std::complex<double>>(obj.ptr());
return static_cast<c10::complex<double>>(c_obj);
} else if (torch::is_symint_node(obj)) {
} else if (torch::is_symint(obj)) {
return py::cast<c10::SymInt>(obj);
} else if (torch::is_symfloat_node(obj)) {
} else if (torch::is_symfloat(obj)) {
return py::cast<c10::SymFloat>(obj);
} else {
throw py::cast_error(

View File

@ -136,10 +136,10 @@ static inline NodePtr MakeSizeDiv(const Value& a, const Value& b) {
inline Value GetSymIntValue(c10::SymInt a) {
return Value(
a.is_symbolic() ? dynamic_cast<torch::lazy::SymIntNodeImpl*>(
a.toSymIntNodeImpl().get())
->node_
: MakeScalar(a.as_int_unchecked(), at::kLong),
a.is_symbolic()
? dynamic_cast<torch::lazy::SymNodeImpl*>(a.toSymNodeImpl().get())
->node_
: MakeScalar(a.as_int_unchecked(), at::kLong),
0);
}

View File

@ -451,11 +451,11 @@ std::vector<Shape> compute_shape_expand(
std::vector<int64_t> target_size(_sizes.size());
for (const auto idx : c10::irange(_sizes.size())) {
if (_sizes[idx].is_symbolic()) {
c10::SymIntNode symbolicIntNode = _sizes[idx].toSymIntNodeImpl();
auto* lazySymIntNode =
dynamic_cast<torch::lazy::SymIntNodeImpl*>(symbolicIntNode.get());
TORCH_INTERNAL_ASSERT(lazySymIntNode);
auto size_node = lazySymIntNode->node_;
c10::SymNode symbolicIntNode = _sizes[idx].toSymNodeImpl();
auto* lazySymNode =
dynamic_cast<torch::lazy::SymNodeImpl*>(symbolicIntNode.get());
TORCH_INTERNAL_ASSERT(lazySymNode);
auto size_node = lazySymNode->node_;
auto static_value =
std::dynamic_pointer_cast<torch::lazy::DimensionNode>(size_node)
->getStaticValue();

View File

@ -4,7 +4,7 @@
#include <c10/core/ScalarType.h>
#include <c10/core/SymInt.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/core/SymIntNodeImpl.h>
#include <c10/core/SymNodeImpl.h>
#include <c10/macros/Export.h>
#include <c10/util/Optional.h>
#include <torch/csrc/lazy/backend/backend_data.h>

View File

@ -1,6 +1,6 @@
#pragma once
#include <c10/core/SymIntNodeImpl.h>
#include <c10/core/SymNodeImpl.h>
#include <c10/util/intrusive_ptr.h>
#include <torch/csrc/lazy/backend/backend_data.h>
#include <torch/csrc/lazy/backend/backend_device.h>
@ -10,12 +10,9 @@
namespace torch {
namespace lazy {
class TORCH_API SymIntNodeImpl : public c10::SymIntNodeImpl {
class TORCH_API SymNodeImpl : public c10::SymNodeImpl {
public:
SymIntNodeImpl(NodePtr ptr) : node_(std::move(ptr)){};
c10::SymIntNode add(const c10::SymIntNode& other) override {
TORCH_CHECK(false, "NYI");
}
SymNodeImpl(NodePtr ptr) : node_(std::move(ptr)){};
NodePtr node_;
};

View File

@ -685,7 +685,7 @@ static bool is_int_list(
// NB: do NOT check that the later arguments are ints, as this is
// BC-breaking for FX
for (int i = 1; i < len; i++) {
if (torch::is_symint_node(
if (torch::is_symint(
py::reinterpret_steal<py::object>(PySequence_GetItem(obj, i)))) {
if (failed_idx != nullptr) {
*failed_idx = i;
@ -716,9 +716,9 @@ static bool is_int_list(
static bool is_int_or_symint(PyObject* obj) {
// THPUtils_checkIndex may call __index__ or __int__
// which may have side effects if obj is a symint node
// so we do `is_symint_node` check first
// so we do `is_symint` check first
// TODO: maybe we should be using checkLong here?
return torch::is_symint_node(py::handle(obj)) || THPUtils_checkIndex(obj);
return torch::is_symint(py::handle(obj)) || THPUtils_checkIndex(obj);
}
static bool is_int_or_symint_list(
@ -1570,13 +1570,13 @@ at::Tensor PythonArgs::tensor_slow(int i) {
// NB: we DO NOT put symbolic ints/floats into the Scalar itself,
// because although Scalar supports SymInt/SymFloat, the subsequent
// conversion to Tensor does not. Instead, do it out of band.
} else if (torch::is_symint_node(py::handle(obj))) {
} else if (torch::is_symint(py::handle(obj))) {
save_symint = true;
// This scalar value doesn't matter, it shouldn't ever actually
// get read out. Make it a big and weird looking number to help
// people figure out if there's aproblem.
scalar = at::Scalar(7777777);
} else if (torch::is_symfloat_node(py::handle(obj))) {
} else if (torch::is_symfloat(py::handle(obj))) {
save_symint = true;
scalar = at::Scalar(std::numeric_limits<double>::quiet_NaN());
} else {
@ -1633,11 +1633,11 @@ at::Scalar PythonArgs::scalar_slow(PyObject* arg) {
return at::Scalar(THPUtils_unpackComplexDouble(arg));
}
if (torch::is_symint_node(arg)) {
if (torch::is_symint(arg)) {
return at::Scalar(py::cast<c10::SymInt>(arg));
}
if (torch::is_symfloat_node(arg)) {
if (torch::is_symfloat(arg)) {
return at::Scalar(py::cast<c10::SymFloat>(arg));
}

View File

@ -61,6 +61,7 @@
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/python_symnode.h>
#include <torch/csrc/utils/six.h>
#include <ATen/PythonTorchFunctionTLS.h>
@ -69,7 +70,7 @@
#include <c10/util/irange.h>
#include <c10/core/SymFloat.h>
#include <c10/core/SymIntNodeImpl.h>
#include <c10/core/SymNodeImpl.h>
#include <array>
#include <cstddef>
@ -78,30 +79,6 @@
#include <string>
#include <vector>
namespace torch {
inline bool is_symint_node(py::handle obj) {
auto static tp_symn = py::type::of<c10::SymIntNodeImpl>();
if (py::isinstance(obj, tp_symn)) {
TORCH_CHECK(
!jit::tracer::isTracing(), "JIT tracing of SymInts isn't supported!");
return true;
}
return false;
}
inline bool is_symfloat_node(py::handle obj) {
auto static tp_symn = py::type::of<c10::SymFloatNodeImpl>();
if (py::isinstance(obj, tp_symn)) {
TORCH_CHECK(
!jit::tracer::isTracing(), "JIT tracing of SymFloats isn't supported!");
return true;
}
return false;
}
} // namespace torch
namespace pybind11 {
namespace detail {
template <>
@ -109,8 +86,10 @@ struct type_caster<c10::SymInt> {
public:
PYBIND11_TYPE_CASTER(c10::SymInt, _("SymInt"));
bool load(py::handle src, bool) {
if (torch::is_symint_node(src)) {
value = src.cast<c10::SymIntNodeImpl*>()->toSymInt();
if (torch::is_symint(src)) {
value = c10::SymInt(static_cast<c10::SymNode>(
c10::make_intrusive<torch::impl::PythonSymNodeImpl>(
src.attr("node"))));
return true;
}
@ -126,8 +105,15 @@ struct type_caster<c10::SymInt> {
c10::SymInt si,
return_value_policy /* policy */,
handle /* parent */) {
return si.is_symbolic() ? py::cast(si.toSymIntNodeImpl()).release()
: py::cast(si.expect_int()).release();
if (si.is_symbolic()) {
// TODO: generalize this to work with C++ backed class
auto* py_node = dynamic_cast<torch::impl::PythonSymNodeImpl*>(
si.toSymNodeImpl().get());
TORCH_INTERNAL_ASSERT(py_node);
return torch::get_symint_class()(py_node->getPyObj()).release();
} else {
return py::cast(si.as_int_unchecked()).release();
}
}
};
@ -136,8 +122,10 @@ struct type_caster<c10::SymFloat> {
public:
PYBIND11_TYPE_CASTER(c10::SymFloat, _("SymFloat"));
bool load(py::handle src, bool) {
if (torch::is_symfloat_node(src)) {
value = src.cast<c10::SymFloatNodeImpl*>()->toSymFloat();
if (torch::is_symfloat(src)) {
value = c10::SymFloat(static_cast<c10::SymNode>(
c10::make_intrusive<torch::impl::PythonSymNodeImpl>(
src.attr("node"))));
return true;
}
@ -153,8 +141,15 @@ struct type_caster<c10::SymFloat> {
c10::SymFloat si,
return_value_policy /* policy */,
handle /* parent */) {
return si.is_symbolic() ? py::cast(si.toSymFloatNodeImpl()).release()
: py::cast(si.expect_float()).release();
if (si.is_symbolic()) {
// TODO: generalize this to work with C++ backed class
auto* py_node = dynamic_cast<torch::impl::PythonSymNodeImpl*>(
si.toSymNodeImpl().get());
TORCH_INTERNAL_ASSERT(py_node);
return torch::get_symfloat_class()(py_node->getPyObj()).release();
} else {
return py::cast(si.as_float_unchecked()).release();
}
}
};
} // namespace detail
@ -167,8 +162,7 @@ inline bool THPUtils_checkScalar(PyObject* obj) {
}
#endif
return PyFloat_Check(obj) || PyLong_Check(obj) || PyComplex_Check(obj) ||
torch::is_symint_node(py::handle(obj)) ||
torch::is_symfloat_node(py::handle(obj));
torch::is_symint(py::handle(obj)) || torch::is_symfloat(py::handle(obj));
}
namespace torch {
@ -574,7 +568,7 @@ inline std::vector<int64_t> PythonArgs::intlist(int i) {
inline PyObject* toPyObject(c10::SymInt symint) {
if (symint.is_symbolic()) {
auto r = py::cast(symint.toSymIntNodeImpl()).release().ptr();
auto r = py::cast(symint).release().ptr();
TORCH_INTERNAL_ASSERT(r);
return r;
} else {
@ -609,8 +603,8 @@ inline std::vector<c10::SymInt> PythonArgs::symintlist(int i) {
size1, c10::SymInt(THPUtils_unpackIndex(args[i])));
}
if (size1 > 0 && torch::is_symint_node(py::handle(args[i]))) {
auto si = py::handle(args[i]).cast<c10::SymIntNodeImpl*>()->toSymInt();
if (size1 > 0 && torch::is_symint(py::handle(args[i]))) {
auto si = py::handle(args[i]).cast<c10::SymInt>();
return std::vector<c10::SymInt>(size1, si);
}
@ -652,9 +646,8 @@ inline std::vector<c10::SymInt> PythonArgs::symintlist(int i) {
res.push_back(var.item<int64_t>());
} else {
try {
if (is_symint_node(py::handle(obj))) {
res.push_back(
py::handle(obj).cast<c10::SymIntNodeImpl*>()->toSymInt());
if (is_symint(py::handle(obj))) {
res.push_back(py::handle(obj).cast<c10::SymInt>());
} else {
res.push_back(c10::SymInt(THPUtils_unpackIndex(obj)));
}

View File

@ -0,0 +1,19 @@
#include <torch/csrc/utils/python_symnode.h>
namespace torch {
py::handle get_symint_class() {
// NB: leak
static py::handle symint_class =
py::object(py::module::import("torch").attr("SymInt")).release();
return symint_class;
}
py::handle get_symfloat_class() {
// NB: leak
static py::handle symfloat_class =
py::object(py::module::import("torch").attr("SymFloat")).release();
return symfloat_class;
}
} // namespace torch

View File

@ -0,0 +1,182 @@
#pragma once
#include <c10/core/SafePyObject.h>
#include <c10/core/SymNodeImpl.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/utils/pybind.h>
namespace torch {
TORCH_PYTHON_API py::handle get_symint_class();
TORCH_PYTHON_API py::handle get_symfloat_class();
// NB: These functions must not be called too early, otherwise torch not setup.
// Alternate design is to have torch "register" the object to us
inline bool is_symint(py::handle obj) {
return py::isinstance(obj, get_symint_class());
}
inline bool is_symfloat(py::handle obj) {
return py::isinstance(obj, get_symfloat_class());
}
namespace impl {
// This c10::SymNodeImpl simply backends to a Python object that
// implements the API. The Python object is the source of truth,
// this is just an adapter so C++ calls can get to the object.
class PythonSymNodeImpl : public c10::SymNodeImpl {
public:
PythonSymNodeImpl(py::object pyobj) : c10::SymNodeImpl() {
pyobj_ = std::make_shared<c10::SafePyObject>(
pyobj.release().ptr(), getPyInterpreter());
};
c10::SymNode wrap_int(int64_t num) override {
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr("wrap_int")(num);
return c10::make_intrusive<PythonSymNodeImpl>(r);
}
c10::SymNode wrap_float(double num) override {
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr("wrap_float")(num);
return c10::make_intrusive<PythonSymNodeImpl>(r);
}
bool bool_() override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("bool_")().is(py::handle(Py_True));
}
bool is_int() override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("is_int")().is(py::handle(Py_True));
}
bool is_float() override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("is_float")().is(py::handle(Py_True));
}
int64_t guard_int(const char* file, int64_t line) override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("guard_int")(file, line).cast<int64_t>();
}
double guard_float(const char* file, int64_t line) override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("guard_float")(file, line).cast<double>();
}
int64_t int_() override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("int_")().cast<int64_t>();
}
std::string str() override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("str")().cast<std::string>();
}
c10::SymNode dispatch_common_(const char* fname, const c10::SymNode& other) {
auto pother = dynamic_cast<PythonSymNodeImpl*>(other.get());
TORCH_CHECK(pother);
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr(fname)(pother->getPyObj());
return c10::make_intrusive<PythonSymNodeImpl>(r);
}
c10::SymNode dispatch_common_(const char* fname) {
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr(fname)();
return c10::make_intrusive<PythonSymNodeImpl>(r);
}
c10::SymNode add(const c10::SymNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
c10::SymNode sub(const c10::SymNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
c10::SymNode mul(const c10::SymNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
c10::SymNode truediv(const c10::SymNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
c10::SymNode pow(const c10::SymNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
c10::SymNode floordiv(const c10::SymNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
c10::SymNode mod(const c10::SymNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
c10::SymNode eq(const c10::SymNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
c10::SymNode gt(const c10::SymNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
c10::SymNode lt(const c10::SymNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
c10::SymNode le(const c10::SymNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
c10::SymNode ge(const c10::SymNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
c10::SymNode min(const c10::SymNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
c10::SymNode max(const c10::SymNode& other) override {
return dispatch_common_(__FUNCTION__, other);
}
c10::SymNode ceil() override {
return dispatch_common_(__FUNCTION__);
}
c10::SymNode floor() override {
return dispatch_common_(__FUNCTION__);
}
c10::SymNode neg() override {
return dispatch_common_(__FUNCTION__);
}
c10::SymNode clone() override {
return dispatch_common_(__FUNCTION__);
}
c10::SymNode sym_int() override {
return dispatch_common_(__FUNCTION__);
}
c10::SymNode sym_float() override {
return dispatch_common_(__FUNCTION__);
}
py::handle getPyObj() {
return py::handle(pyobj_.get()->ptr(getPyInterpreter()));
}
std::shared_ptr<c10::SafePyObject> pyobj_ = nullptr;
};
} // namespace impl
} // namespace torch

View File

@ -21,8 +21,9 @@ import operator
from torch.utils._python_dispatch import TorchDispatchMode, _pop_mode_temporarily, _get_current_dispatch_mode
from torch._subclasses import FakeTensor
from .symbolic_shapes import ShapeEnv, SymDispatchMode, PySymInt, PySymFloat
from .symbolic_shapes import ShapeEnv, SymDispatchMode, SymNode
from torch.fx import Proxy
from torch import SymInt, SymFloat
__all__ = ["PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter", "get_proxy", "has_proxy", "py_sym_types"]
aten = torch.ops.aten
@ -55,27 +56,27 @@ def decompose(decomposition_table):
proxy_slot = object()
no_default = object()
py_sym_types = (
PySymInt,
PySymFloat,
)
py_sym_types = (SymInt, SymFloat)
def is_sym_node(node):
assert hasattr(node, 'meta'), "All nodes traced with proxy_tensor should have meta"
return "val" in node.meta and isinstance(node.meta['val'], py_sym_types)
def set_proxy_slot(obj, tracer, proxy):
d = obj.__dict__.setdefault(proxy_slot, weakref.WeakKeyDictionary())
assert isinstance(obj, (torch.Tensor, SymNode)), type(obj)
d = obj.__dict__.setdefault(proxy_slot, weakref.WeakKeyDictionary()) # type: ignore[call-overload]
assert isinstance(d, weakref.WeakKeyDictionary)
d[tracer] = proxy
def has_proxy_slot(obj, tracer):
assert isinstance(obj, (torch.Tensor, SymNode)), type(obj)
return get_proxy_slot(obj, tracer, False, lambda _: True)
# the default argument is what to return if the slot is not set.
# the transform argument is handy if you need to extract a subfield from
# the successfully looked up result (but NOT the default.)
def get_proxy_slot(obj, tracer, default=no_default, transform=lambda x: x):
d = obj.__dict__.get(proxy_slot)
assert isinstance(obj, (torch.Tensor, SymNode)), type(obj)
d = obj.__dict__.get(proxy_slot) # type: ignore[call-overload]
if not d:
if default is no_default:
raise KeyError(f"{obj} is not tracked with proxy for {tracer}")
@ -130,10 +131,8 @@ def track_tensor(tensor, proxy, *, constant, tracer):
def try_set_proxy_slot(outer_s, proxy_callable, *args):
assert callable(proxy_callable)
if isinstance(outer_s, SymInt):
inner_s = outer_s.get_pyobj()
assert isinstance(inner_s, py_sym_types)
set_proxy_slot(inner_s, tracer, thunkify(proxy_callable, inner_s, *args))
inner_s = outer_s.node
set_proxy_slot(inner_s, tracer, thunkify(proxy_callable, outer_s, *args))
# The basic idea is that we need to associate each tensor/SymInt
# with a Proxy. How do we setup this association? We just store
@ -198,7 +197,7 @@ class _ProxyTensor:
def fetch_sym_proxy(tracer):
def inner(e):
n = e.get_pyobj()
n = e.node
if n.constant is not None:
return n.constant
else:
@ -400,8 +399,8 @@ class PythonKeyTracer(Tracer):
return self.create_node('get_attr', qualname, (), {})
elif isinstance(a, (SymInt, SymFloat)):
assert a.get_pyobj().constant is not None
return a.get_pyobj().constant
assert a.node.constant is not None
return a.node.constant
return super().create_arg(a)
@ -432,7 +431,7 @@ def wrap_key(f, tensors, tracer):
)
out = pytree.tree_map_only(
(SymInt, SymFloat),
lambda t: get_proxy_slot(t.get_pyobj(), tracer)(),
lambda t: get_proxy_slot(t.node, tracer)(),
out
)
return out
@ -479,10 +478,6 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
return out
SymInt = torch.SymIntNode
SymFloat = torch.SymFloatNode
class ProxySymDispatchMode(SymDispatchMode):
def __init__(self, tracer):
super().__init__()
@ -501,10 +496,9 @@ class ProxySymDispatchMode(SymDispatchMode):
finally:
self.enable_tracing = old
def _compute_proxy(self, func, args, out):
def _compute_proxy(self, func, args, out: Union[SymInt, SymFloat]):
n_args = tuple(
get_proxy_slot(a, self.tracer)().node if a.constant is None else a.constant
if isinstance(a, py_sym_types) else a
get_proxy_slot(a.node, self.tracer)().node if isinstance(a, py_sym_types) else a
for a in args
)
@ -520,10 +514,11 @@ class ProxySymDispatchMode(SymDispatchMode):
return func(*args, **kwargs)
# Peephole optimize multiply by one
# NB: be careful not to trigger guards here!
if func == operator.mul:
if isinstance(args[1], (PySymInt, PySymFloat)) and args[1].constant == 1:
if isinstance(args[1], int) and args[1] == 1:
return args[0]
elif isinstance(args[0], (PySymInt, PySymFloat)) and args[0].constant == 1:
elif isinstance(args[0], int) and args[0] == 1:
return args[1]
# For speed, we assume there are no nested data structures
@ -535,7 +530,7 @@ class ProxySymDispatchMode(SymDispatchMode):
# Delays tracing out the proxies on this op until we actually need it
p_out_thunk = thunkify(self._compute_proxy, func=func, args=args, out=out)
set_proxy_slot(out, self.tracer, p_out_thunk)
set_proxy_slot(out.node, self.tracer, p_out_thunk)
return out

View File

@ -10,6 +10,7 @@ import traceback
import collections
import textwrap
from torch._subclasses.meta_utils import MetaConverter
from torch import SymInt, SymFloat
try:
import sympy # type: ignore[import]
@ -21,8 +22,8 @@ except ImportError:
aten = torch.ops.aten # type: ignore[has-type]
__all__ = [
"has_symbolic_sizes_strides", "create_contiguous", "PySymInt", "ShapeEnv",
"SymDispatchMode", "PySymFloat", "sym_float", "FloorDiv"
"has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv",
"SymDispatchMode", "sym_float", "FloorDiv", "guard_int", "wrap_node"
]
SYM_FUNCTION_MODE = None
@ -88,32 +89,38 @@ def _handle_sym_dispatch(func, args, kwargs):
finally:
SYM_FUNCTION_MODE = mode
def guard_int(a):
if isinstance(a, SymInt):
return a.node.guard_int("", 0) # NB: uses Python backtrace
assert isinstance(a, int)
return a
def sym_float(a):
if hasattr(a, '__sym_float__'):
return a.__sym_float__()
elif isinstance(a, torch._C.SymFloatNode):
if isinstance(a, SymFloat):
return a
elif hasattr(a, '__sym_float__'):
return a.__sym_float__()
return float(a)
def sym_int(a):
if hasattr(a, '__sym_int__'):
return a.__sym_int__()
elif isinstance(a, torch._C.SymIntNode):
if isinstance(a, SymInt):
return a
elif hasattr(a, '__sym_int__'):
return a.__sym_int__()
return int(a)
# TODO: An incomplete list
# 1. Set variables to be equal when we do equality
# 2. Specialize on 0/1 when we do subtraction
class PySymInt(object):
class SymNode:
"""
PySymInt objects are the primary "symbolic shape" objects that flow through
our program. They're what sit under FakeTensor, and contains our primary
implementation of symbolic shapes.
This is a type erased SymInt/SymFloat which we use to do actual operations.
End users don't touch this. Magic methods are NOT defined on this object.
"""
def __init__(self, expr, shape_env, constant=None):
def __init__(self, expr, shape_env, pytype, constant=None):
self._expr = expr
self.shape_env = shape_env
self.pytype = pytype
self.constant = constant
@property
@ -121,23 +128,49 @@ class PySymInt(object):
self._update_expr()
return self._expr
def wrap(self, num):
return PySymInt(sympy.Integer(num), self.shape_env, constant=num)
def clone(self):
return PySymInt(self.expr, self.shape_env, constant=self.constant)
def _update_expr(self):
self._expr = self.shape_env.replace(self._expr)
def __str__(self):
def to_node(self, num):
if isinstance(num, (SymInt, SymFloat)):
return num.node
elif isinstance(num, int):
return self.wrap_int(num)
elif isinstance(num, float):
return self.wrap_float(num)
else:
# NotImplementedError is important so that Python tries the
# other magic method
raise NotImplementedError(type(num))
def is_int(self):
return self.pytype is int
def is_float(self):
return self.pytype is float
def wrap_int(self, num):
assert isinstance(num, int)
return SymNode(sympy.Integer(num), self.shape_env, int, constant=num)
def wrap_float(self, num):
assert isinstance(num, float)
return SymNode(sympy.Integer(num), self.shape_env, float, constant=num)
def clone(self):
return SymNode(self.expr, self.shape_env, self.pytype, constant=self.constant)
def str(self):
return f"{self.expr}"
def __str__(self):
return self.str()
def __repr__(self):
return f"{self.expr}"
return self.str()
# Today we error on calling int on a symbolic shape, as this is a very accessible footgun.
def __int__(self):
def int_(self):
raise RuntimeError("Trying to extract a concrete int out of a symbolic int")
# You can manually trigger a guard with this function
@ -146,28 +179,35 @@ class PySymInt(object):
# guard occurred
return int(self.shape_env.evaluate_expr(self.expr))
def __sym_float__(self):
def guard_float(self, file, line):
# TODO: use the file/line for some useful diagnostic on why a
# guard occurred
return float(self.shape_env.evaluate_expr(self.expr))
def sym_float(self):
if SYM_FUNCTION_MODE:
return _handle_sym_dispatch(sym_float, (self,), {})
r = _handle_sym_dispatch(sym_float, (wrap_node(self),), {})
assert isinstance(r, (SymInt, SymFloat)), type(r)
return r.node
# TODO: consider constant prop here
# TODO: wrapping the expr with sympy.Float doesn't seem to work, why
# not?
return PySymFloat(self.expr, self.shape_env)
return SymNode(self.expr, self.shape_env, float)
def __bool__(self):
def sym_int(self):
raise NotImplementedError("sym_int NYI")
"""
if SYM_FUNCTION_MODE:
return _handle_sym_dispatch(sym_int, (self,), {})
# TODO: consider constant prop here
# XXX: need to cast float to int in sympy; math.floor is wrong
# because negatives round to zero
return SymNode(self.expr, self.shape_env, int)
"""
def bool_(self):
return bool(self.shape_env.evaluate_expr(self.shape_env.replace(self.expr)))
class PySymFloat:
def __init__(self, expr, shape_env, constant=None):
self.expr = expr
self.shape_env = shape_env
self.constant = constant
def wrap(self, num):
return PySymFloat(sympy.Float(num), self.shape_env, constant=num)
def __str__(self):
return f"{self.expr}"
if HAS_SYMPY:
class FloorDiv(sympy.Function):
@ -238,32 +278,45 @@ unary_magic_methods = {
float_magic_methods = {"add", "sub", "mul", "truediv", "ceil", "floor", "eq", "gt", "lt", "le", "ge", "pow"}
def _make_magic(method, func, py_type):
def wrap_node(x):
if not isinstance(x, SymNode):
return x
if x.constant is not None:
return x.constant
if x.pytype is int:
return SymInt(x)
elif x.pytype is float:
return SymFloat(x)
else:
raise AssertionError(f"unrecognized return type {x.pytype}")
def _make_node_magic(method, func):
func = lru_cache(256)(func)
def magic_impl(self, other):
def binary_magic_impl(self, other):
if method in ["min", "max"]:
op = getattr(builtins, method)
else:
op = getattr(operator, method)
if SYM_FUNCTION_MODE:
return _handle_sym_dispatch(op, (self, other), {})
if isinstance(other, py_type):
other_expr = other.expr
else:
assert isinstance(other, sympy.Expr)
other_expr = other
r = _handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {})
assert isinstance(r, (SymInt, SymFloat)), type(r)
return r.node
assert isinstance(other, SymNode)
other_expr = other.expr
# TODO: consider constant prop here
expr = self.shape_env.replace(self.expr)
other_expr = self.shape_env.replace(other_expr)
out = func(expr, other_expr)
out = sympy.expand(out)
if method in ["truediv"]:
return PySymFloat(out, self.shape_env)
pytype = float
else:
# TODO: relational operators actually technically return a
# PySymBool, this is a type error
return py_type(out, self.shape_env)
pytype = self.pytype
# TODO: relational operators actually technically return a
# PySymBool, this is a type error
return SymNode(out, self.shape_env, pytype)
def unary_magic_impl(self):
if SYM_FUNCTION_MODE:
@ -271,33 +324,55 @@ def _make_magic(method, func, py_type):
op = getattr(math, method)
else:
op = getattr(operator, method)
return _handle_sym_dispatch(op, (self,), {})
r = _handle_sym_dispatch(op, (wrap_node(self),), {})
assert isinstance(r, (SymInt, SymFloat)), type(r)
return r.node
# TODO: consider constant prop here
expr = self.shape_env.replace(self.expr)
out = func(expr)
out = sympy.expand(out)
if method in ["ceil", "floor"]:
return PySymInt(out, self.shape_env)
pytype = int
else:
return py_type(out, self.shape_env)
pytype = self.pytype
return SymNode(out, self.shape_env, pytype)
# this should be wrapped transparently into torch.SymIntNode
if method in unary_magic_methods:
setattr(py_type, method, unary_magic_impl)
setattr(py_type, f"__{method}__", unary_magic_impl)
setattr(SymNode, method, unary_magic_impl)
else:
setattr(py_type, method, magic_impl)
setattr(py_type, f"__{method}__", magic_impl)
if method in reflectable_magic_methods:
setattr(py_type, f"__r{method}__", magic_impl)
setattr(SymNode, method, binary_magic_impl)
for method, func in magic_methods.items():
_make_magic(method, func, PySymInt)
_make_node_magic(method, func)
def _make_user_magic(method, user_type):
# User magic takes care of wrapping the other operand into a node,
# so that our internal logic can assume everything is nodes
def unary_magic_impl(self):
return wrap_node(getattr(self.node, method)())
def binary_magic_impl(self, other):
return wrap_node(getattr(self.node, method)(self.node.to_node(other)))
def rbinary_magic_impl(self, other):
return wrap_node(getattr(self.node.to_node(other), method)(self.node))
if method in unary_magic_methods:
setattr(user_type, f"__{method}__", unary_magic_impl)
else:
setattr(user_type, f"__{method}__", binary_magic_impl)
if method in reflectable_magic_methods:
setattr(user_type, f"__r{method}__", rbinary_magic_impl)
for method, func in magic_methods.items():
_make_user_magic(method, SymInt)
for method, func in magic_methods.items():
if method not in float_magic_methods:
continue
_make_magic(method, func, PySymFloat)
_make_user_magic(method, SymFloat)
del method
del func
@ -390,9 +465,7 @@ class ShapeEnv(object):
return [self.create_symintnode(i) for i in size], [self.create_symintnode(i) for i in stride] # type: ignore[arg-type]
def create_symintnode(self, expr: "sympy.Expr"):
py_sym_int = PySymInt(expr, self)
cpp_sym_int = torch.SymIntNode.new_symint(py_sym_int) # type: ignore[attr-defined]
return cpp_sym_int
return SymInt(SymNode(expr, self, int))
def create_symbol(self, val: int) -> "sympy.Expr":
if not HAS_SYMPY:

View File

@ -498,7 +498,7 @@ class CodeGen(object):
if isinstance(meta_val, FakeTensor):
maybe_type_annotation = f': {dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}'
elif isinstance(meta_val, py_sym_types):
maybe_type_annotation = f': Sym({meta_val.expr})'
maybe_type_annotation = f': Sym({meta_val})'
elif isinstance(meta_val, TensorMetadata):
maybe_type_annotation = f': {dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}'