Add SymInt to Scalar (#84958)

This is by no means comprehensive, but adds initial support for SymInt as a Scalar.

Things that don't work yet but need to:
- for some reason `torch.add(tensor, sym_int)` got matched to the `add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor` schema
- `x + sym_int` failed bc we tried to turn `x` into a sym int:
```
              "__radd__",
              [](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
                auto snb = toSymIntNode(a, b);
                return a->add(snb);
              })
 ```
- Many more things I'm sure

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84958
Approved by: https://github.com/ezyang
This commit is contained in:
Edward Z. Yang
2022-09-25 12:17:01 -07:00
committed by PyTorch MergeBot
parent 33404436aa
commit 9c036aa112
11 changed files with 289 additions and 34 deletions

View File

@ -23,6 +23,12 @@ std::ostream& operator<<(std::ostream & out, Scalar s) {
if (s.isBoolean()) {
return out << (s.toBool() ? "true" : "false");
}
if (s.isSymInt()) {
return out << (s.toSymInt());
}
if (s.isSymFloat()) {
return out << (s.toSymFloat());
}
if (s.isIntegral(false)) {
return out << s.toLong();
}

View File

@ -777,6 +777,12 @@ public:
} else if (s.isBoolean()) {
tag = Tag::Bool;
payload.u.as_bool = s.toBool();
} else if (s.isSymInt()) {
tag = Tag::SymInt;
payload.u.as_intrusive_ptr = s.toSymInt().toSymIntNodeImpl().release();
} else if (s.isSymFloat()) {
tag = Tag::SymFloat;
payload.u.as_intrusive_ptr = s.toSymFloat().toSymFloatNodeImpl().release();
} else {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(s.isIntegral(false), "Unknown type in Scalar");
tag = Tag::Int;
@ -785,7 +791,7 @@ public:
}
bool isScalar() const {
return isDouble() || isInt() || isComplexDouble() || isBool();
return isDouble() || isInt() || isComplexDouble() || isBool() || isSymInt() || isSymFloat();
}
at::Scalar toScalar() const {
@ -797,6 +803,10 @@ public:
return toComplexDouble();
else if (isBool())
return toBool();
else if (isSymInt())
return toSymInt();
else if (isSymFloat())
return toSymFloat();
throw std::runtime_error("IValue is not a Scalar");
}
@ -1144,6 +1154,7 @@ public:
}
union Payload {
// [TriviallyCopyablePayload]
// We use a nested union here so that we can make the copy easy
// and efficient in the non-tensor (i.e., trivially copyable)
// case. Specifically, we do not have to do a switch-on-tag to

View File

@ -2,6 +2,7 @@
#include <iostream>
#include <random>
#include <c10/core/SymInt.h>
// define constants like M_PI and C keywords for MSVC
#ifdef _MSC_VER
#ifndef _USE_MATH_DEFINES
@ -12,6 +13,18 @@
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
// We intentionally test self assignment/move in this file, suppress warnings
// on them
#ifndef _MSC_VER
#pragma GCC diagnostic ignored "-Wpragmas"
#pragma GCC diagnostic ignored "-Wunknown-warning-option"
#pragma GCC diagnostic ignored "-Wself-move"
#endif
#ifdef __clang__
#pragma clang diagnostic ignored "-Wself-assign-overloaded"
#endif
using std::cout;
using namespace at;
@ -179,4 +192,36 @@ TEST(TestScalar, TestFormatting) {
ASSERT_EQ("false", format(Scalar(false)));
ASSERT_EQ("(2,3.1)", format(Scalar(c10::complex<double>(2.0, 3.1))));
ASSERT_EQ("(2,3.1)", format(Scalar(c10::complex<float>(2.0, 3.1))));
ASSERT_EQ("4", format(Scalar(Scalar(4).toSymInt())));
}
TEST(TestSymInt, Basic) {
Scalar foo;
auto a_impl = c10::make_intrusive<c10::SymIntNodeImpl>();
foo = Scalar(a_impl->toSymInt());
ASSERT_EQ(a_impl.use_count(), 2);
Scalar bar{foo};
ASSERT_EQ(a_impl.use_count(), 3);
auto baz = bar;
ASSERT_EQ(a_impl.use_count(), 4);
auto foo2 = std::move(bar);
ASSERT_EQ(a_impl.use_count(), 4);
ASSERT_TRUE(foo2.isSymInt());
// NOLINTNEXTLINE(bugprone-use-after-move,clang-analyzer-cplusplus.Move)
ASSERT_TRUE(bar.isIntegral(false));
foo2 = SymInt(4);
ASSERT_FALSE(foo2.isSymInt());
ASSERT_EQ(foo2.toSymInt().expect_int(), 4);
// NOLINTNEXTLINE(clang-diagnostic-self-assign-overloaded)
foo2 = foo2;
ASSERT_FALSE(foo2.isSymInt());
ASSERT_EQ(foo2.toSymInt().expect_int(), 4);
ASSERT_EQ(a_impl.use_count(), 3);
ASSERT_THROW(foo.to<double>(), c10::Error);
Scalar int_s = 3;
TORCH_CHECK(int_s.toSymInt().expect_int(), 3);
}

View File

@ -7,16 +7,21 @@ Scalar Scalar::operator-() const {
!isBoolean(),
"torch boolean negative, the `-` operator, is not supported.");
if (isFloatingPoint()) {
TORCH_CHECK(!isSymbolic(), "NYI negate symbolic float");
return Scalar(-v.d);
} else if (isComplex()) {
TORCH_INTERNAL_ASSERT(!isSymbolic());
return Scalar(-v.z);
} else {
} else if (isIntegral(false)) {
TORCH_CHECK(!isSymbolic(), "NYI negate symbolic int");
return Scalar(-v.i);
}
TORCH_INTERNAL_ASSERT(false, "unknown ivalue tag ", static_cast<int>(tag));
}
Scalar Scalar::conj() const {
if (isComplex()) {
TORCH_INTERNAL_ASSERT(!isSymbolic());
return Scalar(std::conj(v.z));
} else {
return *this;
@ -25,12 +30,16 @@ Scalar Scalar::conj() const {
Scalar Scalar::log() const {
if (isComplex()) {
TORCH_INTERNAL_ASSERT(!isSymbolic());
return std::log(v.z);
} else if (isFloatingPoint()) {
TORCH_CHECK(!isSymbolic(), "NYI log symbolic float");
return std::log(v.d);
} else {
} else if (isIntegral(false)) {
TORCH_CHECK(!isSymbolic(), "NYI log symbolic int");
return std::log(v.i);
}
TORCH_INTERNAL_ASSERT(false, "unknown ivalue tag ", static_cast<int>(tag));
}
} // namespace c10

View File

@ -9,10 +9,13 @@
#include <c10/core/OptionalRef.h>
#include <c10/core/ScalarType.h>
#include <c10/core/SymFloat.h>
#include <c10/core/SymInt.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/Half.h>
#include <c10/util/TypeCast.h>
#include <c10/util/intrusive_ptr.h>
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
@ -33,6 +36,17 @@ class C10_API Scalar {
public:
Scalar() : Scalar(int64_t(0)) {}
void destroy() {
if (Tag::HAS_si == tag || Tag::HAS_sd == tag) {
raw::intrusive_ptr::decref(v.p);
v.p = nullptr;
}
}
~Scalar() {
destroy();
}
#define DEFINE_IMPLICIT_CTOR(type, name) \
Scalar(type vv) : Scalar(vv, true) {}
@ -61,35 +75,63 @@ class C10_API Scalar {
} \
if (Tag::HAS_b == tag) { \
return checked_convert<type, bool>(v.i, #type); \
} else { \
} else if (Tag::HAS_i == tag) { \
return checked_convert<type, int64_t>(v.i, #type); \
} else if (Tag::HAS_si == tag) { \
TORCH_CHECK(false, "tried to get " #name " out of SymInt") \
} else if (Tag::HAS_sd == tag) { \
TORCH_CHECK(false, "tried to get " #name " out of SymFloat") \
} \
TORCH_CHECK(false) \
}
// TODO: Support ComplexHalf accessor
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_ACCESSOR)
#undef DEFINE_ACCESSOR
SymInt toSymInt() const {
if (Tag::HAS_si == tag) {
return c10::SymInt::toSymInt(intrusive_ptr<SymIntNodeImpl>::reclaim_copy(
static_cast<SymIntNodeImpl*>(v.p)));
} else {
return toLong();
}
}
SymFloat toSymFloat() const {
if (Tag::HAS_sd == tag) {
return c10::SymFloat::toSymFloat(
intrusive_ptr<SymFloatNodeImpl>::reclaim_copy(
static_cast<SymFloatNodeImpl*>(v.p)));
} else {
return toLong();
}
}
// also support scalar.to<int64_t>();
// Deleted for unsupported types, but specialized below for supported types
template <typename T>
T to() const = delete;
// audit uses of data_ptr
const void* data_ptr() const {
TORCH_INTERNAL_ASSERT(!isSymbolic());
return static_cast<const void*>(&v);
}
#undef DEFINE_ACCESSOR
bool isFloatingPoint() const {
return Tag::HAS_d == tag;
return Tag::HAS_d == tag || Tag::HAS_sd == tag;
}
C10_DEPRECATED_MESSAGE(
"isIntegral is deprecated. Please use the overload with 'includeBool' parameter instead.")
bool isIntegral() const {
return Tag::HAS_i == tag;
return Tag::HAS_i == tag || Tag::HAS_si == tag;
}
bool isIntegral(bool includeBool) const {
return Tag::HAS_i == tag || (includeBool && isBoolean());
return Tag::HAS_i == tag || Tag::HAS_si == tag ||
(includeBool && isBoolean());
}
bool isComplex() const {
@ -99,6 +141,37 @@ class C10_API Scalar {
return Tag::HAS_b == tag;
}
// you probably don't actually want these; they're mostly for testing
bool isSymInt() const {
return Tag::HAS_si == tag;
}
bool isSymFloat() const {
return Tag::HAS_sd == tag;
}
bool isSymbolic() const {
return Tag::HAS_si == tag || Tag::HAS_sd == tag;
}
C10_ALWAYS_INLINE Scalar& operator=(Scalar&& other) {
if (&other == this) {
return *this;
}
destroy();
moveFrom(std::move(other));
return *this;
}
C10_ALWAYS_INLINE Scalar& operator=(const Scalar& other) {
if (&other == this) {
return *this;
}
*this = Scalar(other);
return *this;
}
Scalar operator-() const;
Scalar conj() const;
Scalar log() const;
@ -108,15 +181,21 @@ class C10_API Scalar {
typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0>
bool equal(T num) const {
if (isComplex()) {
TORCH_INTERNAL_ASSERT(!isSymbolic());
auto val = v.z;
return (val.real() == num) && (val.imag() == T());
} else if (isFloatingPoint()) {
TORCH_CHECK(!isSymbolic(), "NYI SymFloat equality");
return v.d == num;
} else if (isIntegral(/*includeBool=*/false)) {
TORCH_CHECK(!isSymbolic(), "NYI SymInt equality");
return v.i == num;
} else {
} else if (isBoolean()) {
// boolean scalar does not equal to a non boolean value
TORCH_INTERNAL_ASSERT(!isSymbolic());
return false;
} else {
TORCH_INTERNAL_ASSERT(false);
}
}
@ -125,19 +204,26 @@ class C10_API Scalar {
typename std::enable_if<c10::is_complex<T>::value, int>::type = 0>
bool equal(T num) const {
if (isComplex()) {
TORCH_INTERNAL_ASSERT(!isSymbolic());
return v.z == num;
} else if (isFloatingPoint()) {
TORCH_CHECK(!isSymbolic(), "NYI SymFloat equality");
return (v.d == num.real()) && (num.imag() == T());
} else if (isIntegral(/*includeBool=*/false)) {
TORCH_CHECK(!isSymbolic(), "NYI SymInt equality");
return (v.i == num.real()) && (num.imag() == T());
} else {
} else if (isBoolean()) {
// boolean scalar does not equal to a non boolean value
TORCH_INTERNAL_ASSERT(!isSymbolic());
return false;
} else {
TORCH_INTERNAL_ASSERT(false);
}
}
bool equal(bool num) const {
if (isBoolean()) {
TORCH_INTERNAL_ASSERT(!isSymbolic());
return static_cast<bool>(v.i) == num;
} else {
return false;
@ -158,7 +244,62 @@ class C10_API Scalar {
}
}
Scalar(Scalar&& rhs) noexcept : tag(rhs.tag) {
moveFrom(std::move(rhs));
}
Scalar(const Scalar& rhs) : tag(rhs.tag), v(rhs.v) {
if (isSymbolic()) {
c10::raw::intrusive_ptr::incref(v.p);
}
}
Scalar(c10::SymInt si) {
if (si.is_symbolic()) {
tag = Tag::HAS_si;
v.p = std::move(si).release();
} else {
tag = Tag::HAS_i;
v.i = si.as_int_unchecked();
}
}
Scalar(c10::SymFloat sd) {
if (sd.is_symbolic()) {
tag = Tag::HAS_sd;
v.p = std::move(sd).release();
} else {
tag = Tag::HAS_d;
v.d = sd.as_float_unchecked();
}
}
// We can't set v in the initializer list using the
// syntax v{ .member = ... } because it doesn't work on MSVC
private:
enum class Tag { HAS_d, HAS_i, HAS_z, HAS_b, HAS_sd, HAS_si };
// NB: assumes that self has already been cleared
C10_ALWAYS_INLINE void moveFrom(Scalar&& rhs) noexcept {
v = rhs.v;
tag = rhs.tag;
if (rhs.tag == Tag::HAS_si || rhs.tag == Tag::HAS_sd) {
// Move out of scalar
rhs.tag = Tag::HAS_i;
rhs.v.i = 0;
}
}
Tag tag;
union v_t {
double d;
int64_t i;
c10::complex<double> z;
c10::intrusive_ptr_target* p;
v_t() {} // default constructor
} v;
template <
typename T,
typename std::enable_if<
@ -183,18 +324,6 @@ class C10_API Scalar {
Scalar(T vv, bool) : tag(Tag::HAS_z) {
v.z = convert<decltype(v.z), T>(vv);
}
// We can't set v in the initializer list using the
// syntax v{ .member = ... } because it doesn't work on MSVC
enum class Tag { HAS_d, HAS_i, HAS_z, HAS_b };
Tag tag;
union v_t {
double d;
int64_t i;
c10::complex<double> z;
v_t() {} // default constructor
} v;
};
using OptionalScalarRef = c10::OptionalRef<Scalar>;

View File

@ -22,6 +22,10 @@ class C10_API SymFloat {
return ptr_.get();
}
SymFloatNodeImpl* release() && {
return std::move(ptr_).release();
}
SymFloatNode toSymFloatNodeImpl() const;
static c10::SymFloat toSymFloat(SymFloatNode sin);

View File

@ -102,8 +102,19 @@ class C10_API SymInt {
SymIntNode::reclaim(toSymIntNodeImplUnowned()); // steal
}
}
SymIntNodeImpl* release() && {
TORCH_INTERNAL_ASSERT(is_symbolic());
auto* r = toSymIntNodeImplUnowned();
data_ = 0; // transfer ownership
return r;
}
#else
void release_() {}
SymIntNodeImpl* release() && {
TORCH_INTERNAL_ASSERT(false);
}
#endif
SymIntNode toSymIntNodeImpl() const;

View File

@ -11,6 +11,7 @@ import operator
import itertools
from torch.utils._pytree import tree_map
from torch.fx.experimental.symbolic_shapes import ShapeEnv, PySymInt, sym_float
from torch.utils._python_dispatch import TorchDispatchMode
aten = torch.ops.aten
@ -321,6 +322,27 @@ class TestPySymInt(TestCase):
a0 = shape_env.create_symint("a0", 2)
self.assertRaisesRegex(RuntimeError, "Trying to extract", lambda: int(a0))
@skipIfNoSympy
def test_symint_as_scalar(self):
shape_env = ShapeEnv()
a0 = shape_env.create_symint("a0", 2)
sym_int_encountered = False
class TestSymInt(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
assert func == torch.ops.aten.add.Tensor
nonlocal sym_int_encountered
sym_int_encountered = kwargs["alpha"] is a0
kwargs["alpha"] = 0
return func(*args)
x = torch.rand([4, 4])
with TestSymInt():
y = torch.add(x, x, alpha=a0)
self.assertTrue(sym_int_encountered)
if __name__ == '__main__':
run_tests()

View File

@ -741,13 +741,15 @@ auto FunctionParameter::check(
return true;
}
if (allow_numbers_as_tensors) {
return THPUtils_checkScalar(obj) ||
torch::is_symint_node(py::handle(obj)) ||
torch::is_symfloat_node(py::handle(obj));
return THPUtils_checkScalar(obj);
}
return false;
}
case ParameterType::SCALAR:
if (THPUtils_checkScalar(obj)) {
return true;
}
// fallthrough
case ParameterType::COMPLEX:
if (PyComplex_Check(obj)) {
return true;
@ -1498,6 +1500,9 @@ at::Tensor PythonArgs::tensor_slow(int i) {
scalar = at::Scalar(THPUtils_unpackComplexDouble(obj));
} else if (THPUtils_checkDouble(obj)) {
scalar = at::Scalar(THPUtils_unpackDouble(obj));
// NB: we DO NOT put symbolic ints/floats into the Scalar itself,
// because although Scalar supports SymInt/SymFloat, the subsequent
// conversion to Tensor does not. Instead, do it out of band.
} else if (torch::is_symint_node(py::handle(obj))) {
save_symint = true;
// This scalar value doesn't matter, it shouldn't ever actually
@ -1560,6 +1565,15 @@ at::Scalar PythonArgs::scalar_slow(PyObject* arg) {
if (PyComplex_Check(arg)) {
return at::Scalar(THPUtils_unpackComplexDouble(arg));
}
if (torch::is_symint_node(arg)) {
return at::Scalar(py::cast<c10::SymInt>(arg));
}
if (torch::is_symfloat_node(arg)) {
return at::Scalar(py::cast<c10::SymFloat>(arg));
}
return at::Scalar(THPUtils_unpackDouble(arg));
}

View File

@ -79,6 +79,7 @@
#include <vector>
namespace torch {
inline bool is_symint_node(py::handle obj) {
auto static tp_symn = py::type::of<c10::SymIntNodeImpl>();
if (py::isinstance(obj, tp_symn)) {
@ -98,6 +99,7 @@ inline bool is_symfloat_node(py::handle obj) {
}
return false;
}
} // namespace torch
namespace pybind11 {
@ -158,6 +160,17 @@ struct type_caster<c10::SymFloat> {
} // namespace detail
} // namespace pybind11
inline bool THPUtils_checkScalar(PyObject* obj) {
#ifdef USE_NUMPY
if (torch::utils::is_numpy_scalar(obj)) {
return true;
}
#endif
return PyFloat_Check(obj) || PyLong_Check(obj) || PyComplex_Check(obj) ||
torch::is_symint_node(py::handle(obj)) ||
torch::is_symfloat_node(py::handle(obj));
}
namespace torch {
bool should_allow_numbers_as_tensors(const std::string& name);

View File

@ -139,15 +139,6 @@ inline bool THPUtils_checkDouble(PyObject* obj) {
return PyFloat_Check(obj) || PyLong_Check(obj);
}
inline bool THPUtils_checkScalar(PyObject* obj) {
#ifdef USE_NUMPY
if (torch::utils::is_numpy_scalar(obj)) {
return true;
}
#endif
return PyFloat_Check(obj) || PyLong_Check(obj) || PyComplex_Check(obj);
}
inline double THPUtils_unpackDouble(PyObject* obj) {
if (PyFloat_Check(obj)) {
return PyFloat_AS_DOUBLE(obj);