mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add SymInt to Scalar (#84958)
This is by no means comprehensive, but adds initial support for SymInt as a Scalar. Things that don't work yet but need to: - for some reason `torch.add(tensor, sym_int)` got matched to the `add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor` schema - `x + sym_int` failed bc we tried to turn `x` into a sym int: ``` "__radd__", [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { auto snb = toSymIntNode(a, b); return a->add(snb); }) ``` - Many more things I'm sure Pull Request resolved: https://github.com/pytorch/pytorch/pull/84958 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
33404436aa
commit
9c036aa112
@ -23,6 +23,12 @@ std::ostream& operator<<(std::ostream & out, Scalar s) {
|
||||
if (s.isBoolean()) {
|
||||
return out << (s.toBool() ? "true" : "false");
|
||||
}
|
||||
if (s.isSymInt()) {
|
||||
return out << (s.toSymInt());
|
||||
}
|
||||
if (s.isSymFloat()) {
|
||||
return out << (s.toSymFloat());
|
||||
}
|
||||
if (s.isIntegral(false)) {
|
||||
return out << s.toLong();
|
||||
}
|
||||
|
@ -777,6 +777,12 @@ public:
|
||||
} else if (s.isBoolean()) {
|
||||
tag = Tag::Bool;
|
||||
payload.u.as_bool = s.toBool();
|
||||
} else if (s.isSymInt()) {
|
||||
tag = Tag::SymInt;
|
||||
payload.u.as_intrusive_ptr = s.toSymInt().toSymIntNodeImpl().release();
|
||||
} else if (s.isSymFloat()) {
|
||||
tag = Tag::SymFloat;
|
||||
payload.u.as_intrusive_ptr = s.toSymFloat().toSymFloatNodeImpl().release();
|
||||
} else {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(s.isIntegral(false), "Unknown type in Scalar");
|
||||
tag = Tag::Int;
|
||||
@ -785,7 +791,7 @@ public:
|
||||
}
|
||||
|
||||
bool isScalar() const {
|
||||
return isDouble() || isInt() || isComplexDouble() || isBool();
|
||||
return isDouble() || isInt() || isComplexDouble() || isBool() || isSymInt() || isSymFloat();
|
||||
}
|
||||
|
||||
at::Scalar toScalar() const {
|
||||
@ -797,6 +803,10 @@ public:
|
||||
return toComplexDouble();
|
||||
else if (isBool())
|
||||
return toBool();
|
||||
else if (isSymInt())
|
||||
return toSymInt();
|
||||
else if (isSymFloat())
|
||||
return toSymFloat();
|
||||
throw std::runtime_error("IValue is not a Scalar");
|
||||
}
|
||||
|
||||
@ -1144,6 +1154,7 @@ public:
|
||||
}
|
||||
|
||||
union Payload {
|
||||
// [TriviallyCopyablePayload]
|
||||
// We use a nested union here so that we can make the copy easy
|
||||
// and efficient in the non-tensor (i.e., trivially copyable)
|
||||
// case. Specifically, we do not have to do a switch-on-tag to
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <c10/core/SymInt.h>
|
||||
// define constants like M_PI and C keywords for MSVC
|
||||
#ifdef _MSC_VER
|
||||
#ifndef _USE_MATH_DEFINES
|
||||
@ -12,6 +13,18 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
|
||||
// We intentionally test self assignment/move in this file, suppress warnings
|
||||
// on them
|
||||
#ifndef _MSC_VER
|
||||
#pragma GCC diagnostic ignored "-Wpragmas"
|
||||
#pragma GCC diagnostic ignored "-Wunknown-warning-option"
|
||||
#pragma GCC diagnostic ignored "-Wself-move"
|
||||
#endif
|
||||
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic ignored "-Wself-assign-overloaded"
|
||||
#endif
|
||||
|
||||
using std::cout;
|
||||
using namespace at;
|
||||
|
||||
@ -179,4 +192,36 @@ TEST(TestScalar, TestFormatting) {
|
||||
ASSERT_EQ("false", format(Scalar(false)));
|
||||
ASSERT_EQ("(2,3.1)", format(Scalar(c10::complex<double>(2.0, 3.1))));
|
||||
ASSERT_EQ("(2,3.1)", format(Scalar(c10::complex<float>(2.0, 3.1))));
|
||||
ASSERT_EQ("4", format(Scalar(Scalar(4).toSymInt())));
|
||||
}
|
||||
|
||||
TEST(TestSymInt, Basic) {
|
||||
Scalar foo;
|
||||
auto a_impl = c10::make_intrusive<c10::SymIntNodeImpl>();
|
||||
foo = Scalar(a_impl->toSymInt());
|
||||
ASSERT_EQ(a_impl.use_count(), 2);
|
||||
Scalar bar{foo};
|
||||
ASSERT_EQ(a_impl.use_count(), 3);
|
||||
auto baz = bar;
|
||||
ASSERT_EQ(a_impl.use_count(), 4);
|
||||
auto foo2 = std::move(bar);
|
||||
ASSERT_EQ(a_impl.use_count(), 4);
|
||||
ASSERT_TRUE(foo2.isSymInt());
|
||||
// NOLINTNEXTLINE(bugprone-use-after-move,clang-analyzer-cplusplus.Move)
|
||||
ASSERT_TRUE(bar.isIntegral(false));
|
||||
foo2 = SymInt(4);
|
||||
ASSERT_FALSE(foo2.isSymInt());
|
||||
ASSERT_EQ(foo2.toSymInt().expect_int(), 4);
|
||||
// NOLINTNEXTLINE(clang-diagnostic-self-assign-overloaded)
|
||||
foo2 = foo2;
|
||||
ASSERT_FALSE(foo2.isSymInt());
|
||||
ASSERT_EQ(foo2.toSymInt().expect_int(), 4);
|
||||
|
||||
ASSERT_EQ(a_impl.use_count(), 3);
|
||||
|
||||
ASSERT_THROW(foo.to<double>(), c10::Error);
|
||||
|
||||
Scalar int_s = 3;
|
||||
TORCH_CHECK(int_s.toSymInt().expect_int(), 3);
|
||||
|
||||
}
|
||||
|
@ -7,16 +7,21 @@ Scalar Scalar::operator-() const {
|
||||
!isBoolean(),
|
||||
"torch boolean negative, the `-` operator, is not supported.");
|
||||
if (isFloatingPoint()) {
|
||||
TORCH_CHECK(!isSymbolic(), "NYI negate symbolic float");
|
||||
return Scalar(-v.d);
|
||||
} else if (isComplex()) {
|
||||
TORCH_INTERNAL_ASSERT(!isSymbolic());
|
||||
return Scalar(-v.z);
|
||||
} else {
|
||||
} else if (isIntegral(false)) {
|
||||
TORCH_CHECK(!isSymbolic(), "NYI negate symbolic int");
|
||||
return Scalar(-v.i);
|
||||
}
|
||||
TORCH_INTERNAL_ASSERT(false, "unknown ivalue tag ", static_cast<int>(tag));
|
||||
}
|
||||
|
||||
Scalar Scalar::conj() const {
|
||||
if (isComplex()) {
|
||||
TORCH_INTERNAL_ASSERT(!isSymbolic());
|
||||
return Scalar(std::conj(v.z));
|
||||
} else {
|
||||
return *this;
|
||||
@ -25,12 +30,16 @@ Scalar Scalar::conj() const {
|
||||
|
||||
Scalar Scalar::log() const {
|
||||
if (isComplex()) {
|
||||
TORCH_INTERNAL_ASSERT(!isSymbolic());
|
||||
return std::log(v.z);
|
||||
} else if (isFloatingPoint()) {
|
||||
TORCH_CHECK(!isSymbolic(), "NYI log symbolic float");
|
||||
return std::log(v.d);
|
||||
} else {
|
||||
} else if (isIntegral(false)) {
|
||||
TORCH_CHECK(!isSymbolic(), "NYI log symbolic int");
|
||||
return std::log(v.i);
|
||||
}
|
||||
TORCH_INTERNAL_ASSERT(false, "unknown ivalue tag ", static_cast<int>(tag));
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
@ -9,10 +9,13 @@
|
||||
|
||||
#include <c10/core/OptionalRef.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/core/SymFloat.h>
|
||||
#include <c10/core/SymInt.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/Half.h>
|
||||
#include <c10/util/TypeCast.h>
|
||||
#include <c10/util/intrusive_ptr.h>
|
||||
|
||||
C10_CLANG_DIAGNOSTIC_PUSH()
|
||||
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
|
||||
@ -33,6 +36,17 @@ class C10_API Scalar {
|
||||
public:
|
||||
Scalar() : Scalar(int64_t(0)) {}
|
||||
|
||||
void destroy() {
|
||||
if (Tag::HAS_si == tag || Tag::HAS_sd == tag) {
|
||||
raw::intrusive_ptr::decref(v.p);
|
||||
v.p = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
~Scalar() {
|
||||
destroy();
|
||||
}
|
||||
|
||||
#define DEFINE_IMPLICIT_CTOR(type, name) \
|
||||
Scalar(type vv) : Scalar(vv, true) {}
|
||||
|
||||
@ -61,35 +75,63 @@ class C10_API Scalar {
|
||||
} \
|
||||
if (Tag::HAS_b == tag) { \
|
||||
return checked_convert<type, bool>(v.i, #type); \
|
||||
} else { \
|
||||
} else if (Tag::HAS_i == tag) { \
|
||||
return checked_convert<type, int64_t>(v.i, #type); \
|
||||
} else if (Tag::HAS_si == tag) { \
|
||||
TORCH_CHECK(false, "tried to get " #name " out of SymInt") \
|
||||
} else if (Tag::HAS_sd == tag) { \
|
||||
TORCH_CHECK(false, "tried to get " #name " out of SymFloat") \
|
||||
} \
|
||||
TORCH_CHECK(false) \
|
||||
}
|
||||
|
||||
// TODO: Support ComplexHalf accessor
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_ACCESSOR)
|
||||
|
||||
#undef DEFINE_ACCESSOR
|
||||
|
||||
SymInt toSymInt() const {
|
||||
if (Tag::HAS_si == tag) {
|
||||
return c10::SymInt::toSymInt(intrusive_ptr<SymIntNodeImpl>::reclaim_copy(
|
||||
static_cast<SymIntNodeImpl*>(v.p)));
|
||||
} else {
|
||||
return toLong();
|
||||
}
|
||||
}
|
||||
|
||||
SymFloat toSymFloat() const {
|
||||
if (Tag::HAS_sd == tag) {
|
||||
return c10::SymFloat::toSymFloat(
|
||||
intrusive_ptr<SymFloatNodeImpl>::reclaim_copy(
|
||||
static_cast<SymFloatNodeImpl*>(v.p)));
|
||||
} else {
|
||||
return toLong();
|
||||
}
|
||||
}
|
||||
|
||||
// also support scalar.to<int64_t>();
|
||||
// Deleted for unsupported types, but specialized below for supported types
|
||||
template <typename T>
|
||||
T to() const = delete;
|
||||
|
||||
// audit uses of data_ptr
|
||||
const void* data_ptr() const {
|
||||
TORCH_INTERNAL_ASSERT(!isSymbolic());
|
||||
return static_cast<const void*>(&v);
|
||||
}
|
||||
|
||||
#undef DEFINE_ACCESSOR
|
||||
bool isFloatingPoint() const {
|
||||
return Tag::HAS_d == tag;
|
||||
return Tag::HAS_d == tag || Tag::HAS_sd == tag;
|
||||
}
|
||||
|
||||
C10_DEPRECATED_MESSAGE(
|
||||
"isIntegral is deprecated. Please use the overload with 'includeBool' parameter instead.")
|
||||
bool isIntegral() const {
|
||||
return Tag::HAS_i == tag;
|
||||
return Tag::HAS_i == tag || Tag::HAS_si == tag;
|
||||
}
|
||||
bool isIntegral(bool includeBool) const {
|
||||
return Tag::HAS_i == tag || (includeBool && isBoolean());
|
||||
return Tag::HAS_i == tag || Tag::HAS_si == tag ||
|
||||
(includeBool && isBoolean());
|
||||
}
|
||||
|
||||
bool isComplex() const {
|
||||
@ -99,6 +141,37 @@ class C10_API Scalar {
|
||||
return Tag::HAS_b == tag;
|
||||
}
|
||||
|
||||
// you probably don't actually want these; they're mostly for testing
|
||||
bool isSymInt() const {
|
||||
return Tag::HAS_si == tag;
|
||||
}
|
||||
bool isSymFloat() const {
|
||||
return Tag::HAS_sd == tag;
|
||||
}
|
||||
|
||||
bool isSymbolic() const {
|
||||
return Tag::HAS_si == tag || Tag::HAS_sd == tag;
|
||||
}
|
||||
|
||||
C10_ALWAYS_INLINE Scalar& operator=(Scalar&& other) {
|
||||
if (&other == this) {
|
||||
return *this;
|
||||
}
|
||||
|
||||
destroy();
|
||||
moveFrom(std::move(other));
|
||||
return *this;
|
||||
}
|
||||
|
||||
C10_ALWAYS_INLINE Scalar& operator=(const Scalar& other) {
|
||||
if (&other == this) {
|
||||
return *this;
|
||||
}
|
||||
|
||||
*this = Scalar(other);
|
||||
return *this;
|
||||
}
|
||||
|
||||
Scalar operator-() const;
|
||||
Scalar conj() const;
|
||||
Scalar log() const;
|
||||
@ -108,15 +181,21 @@ class C10_API Scalar {
|
||||
typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0>
|
||||
bool equal(T num) const {
|
||||
if (isComplex()) {
|
||||
TORCH_INTERNAL_ASSERT(!isSymbolic());
|
||||
auto val = v.z;
|
||||
return (val.real() == num) && (val.imag() == T());
|
||||
} else if (isFloatingPoint()) {
|
||||
TORCH_CHECK(!isSymbolic(), "NYI SymFloat equality");
|
||||
return v.d == num;
|
||||
} else if (isIntegral(/*includeBool=*/false)) {
|
||||
TORCH_CHECK(!isSymbolic(), "NYI SymInt equality");
|
||||
return v.i == num;
|
||||
} else {
|
||||
} else if (isBoolean()) {
|
||||
// boolean scalar does not equal to a non boolean value
|
||||
TORCH_INTERNAL_ASSERT(!isSymbolic());
|
||||
return false;
|
||||
} else {
|
||||
TORCH_INTERNAL_ASSERT(false);
|
||||
}
|
||||
}
|
||||
|
||||
@ -125,19 +204,26 @@ class C10_API Scalar {
|
||||
typename std::enable_if<c10::is_complex<T>::value, int>::type = 0>
|
||||
bool equal(T num) const {
|
||||
if (isComplex()) {
|
||||
TORCH_INTERNAL_ASSERT(!isSymbolic());
|
||||
return v.z == num;
|
||||
} else if (isFloatingPoint()) {
|
||||
TORCH_CHECK(!isSymbolic(), "NYI SymFloat equality");
|
||||
return (v.d == num.real()) && (num.imag() == T());
|
||||
} else if (isIntegral(/*includeBool=*/false)) {
|
||||
TORCH_CHECK(!isSymbolic(), "NYI SymInt equality");
|
||||
return (v.i == num.real()) && (num.imag() == T());
|
||||
} else {
|
||||
} else if (isBoolean()) {
|
||||
// boolean scalar does not equal to a non boolean value
|
||||
TORCH_INTERNAL_ASSERT(!isSymbolic());
|
||||
return false;
|
||||
} else {
|
||||
TORCH_INTERNAL_ASSERT(false);
|
||||
}
|
||||
}
|
||||
|
||||
bool equal(bool num) const {
|
||||
if (isBoolean()) {
|
||||
TORCH_INTERNAL_ASSERT(!isSymbolic());
|
||||
return static_cast<bool>(v.i) == num;
|
||||
} else {
|
||||
return false;
|
||||
@ -158,7 +244,62 @@ class C10_API Scalar {
|
||||
}
|
||||
}
|
||||
|
||||
Scalar(Scalar&& rhs) noexcept : tag(rhs.tag) {
|
||||
moveFrom(std::move(rhs));
|
||||
}
|
||||
|
||||
Scalar(const Scalar& rhs) : tag(rhs.tag), v(rhs.v) {
|
||||
if (isSymbolic()) {
|
||||
c10::raw::intrusive_ptr::incref(v.p);
|
||||
}
|
||||
}
|
||||
|
||||
Scalar(c10::SymInt si) {
|
||||
if (si.is_symbolic()) {
|
||||
tag = Tag::HAS_si;
|
||||
v.p = std::move(si).release();
|
||||
} else {
|
||||
tag = Tag::HAS_i;
|
||||
v.i = si.as_int_unchecked();
|
||||
}
|
||||
}
|
||||
|
||||
Scalar(c10::SymFloat sd) {
|
||||
if (sd.is_symbolic()) {
|
||||
tag = Tag::HAS_sd;
|
||||
v.p = std::move(sd).release();
|
||||
} else {
|
||||
tag = Tag::HAS_d;
|
||||
v.d = sd.as_float_unchecked();
|
||||
}
|
||||
}
|
||||
|
||||
// We can't set v in the initializer list using the
|
||||
// syntax v{ .member = ... } because it doesn't work on MSVC
|
||||
private:
|
||||
enum class Tag { HAS_d, HAS_i, HAS_z, HAS_b, HAS_sd, HAS_si };
|
||||
|
||||
// NB: assumes that self has already been cleared
|
||||
C10_ALWAYS_INLINE void moveFrom(Scalar&& rhs) noexcept {
|
||||
v = rhs.v;
|
||||
tag = rhs.tag;
|
||||
if (rhs.tag == Tag::HAS_si || rhs.tag == Tag::HAS_sd) {
|
||||
// Move out of scalar
|
||||
rhs.tag = Tag::HAS_i;
|
||||
rhs.v.i = 0;
|
||||
}
|
||||
}
|
||||
|
||||
Tag tag;
|
||||
|
||||
union v_t {
|
||||
double d;
|
||||
int64_t i;
|
||||
c10::complex<double> z;
|
||||
c10::intrusive_ptr_target* p;
|
||||
v_t() {} // default constructor
|
||||
} v;
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename std::enable_if<
|
||||
@ -183,18 +324,6 @@ class C10_API Scalar {
|
||||
Scalar(T vv, bool) : tag(Tag::HAS_z) {
|
||||
v.z = convert<decltype(v.z), T>(vv);
|
||||
}
|
||||
|
||||
// We can't set v in the initializer list using the
|
||||
// syntax v{ .member = ... } because it doesn't work on MSVC
|
||||
|
||||
enum class Tag { HAS_d, HAS_i, HAS_z, HAS_b };
|
||||
Tag tag;
|
||||
union v_t {
|
||||
double d;
|
||||
int64_t i;
|
||||
c10::complex<double> z;
|
||||
v_t() {} // default constructor
|
||||
} v;
|
||||
};
|
||||
|
||||
using OptionalScalarRef = c10::OptionalRef<Scalar>;
|
||||
|
@ -22,6 +22,10 @@ class C10_API SymFloat {
|
||||
return ptr_.get();
|
||||
}
|
||||
|
||||
SymFloatNodeImpl* release() && {
|
||||
return std::move(ptr_).release();
|
||||
}
|
||||
|
||||
SymFloatNode toSymFloatNodeImpl() const;
|
||||
static c10::SymFloat toSymFloat(SymFloatNode sin);
|
||||
|
||||
|
@ -102,8 +102,19 @@ class C10_API SymInt {
|
||||
SymIntNode::reclaim(toSymIntNodeImplUnowned()); // steal
|
||||
}
|
||||
}
|
||||
|
||||
SymIntNodeImpl* release() && {
|
||||
TORCH_INTERNAL_ASSERT(is_symbolic());
|
||||
auto* r = toSymIntNodeImplUnowned();
|
||||
data_ = 0; // transfer ownership
|
||||
return r;
|
||||
}
|
||||
#else
|
||||
void release_() {}
|
||||
|
||||
SymIntNodeImpl* release() && {
|
||||
TORCH_INTERNAL_ASSERT(false);
|
||||
}
|
||||
#endif
|
||||
|
||||
SymIntNode toSymIntNodeImpl() const;
|
||||
|
@ -11,6 +11,7 @@ import operator
|
||||
import itertools
|
||||
from torch.utils._pytree import tree_map
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv, PySymInt, sym_float
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
@ -321,6 +322,27 @@ class TestPySymInt(TestCase):
|
||||
a0 = shape_env.create_symint("a0", 2)
|
||||
self.assertRaisesRegex(RuntimeError, "Trying to extract", lambda: int(a0))
|
||||
|
||||
@skipIfNoSympy
|
||||
def test_symint_as_scalar(self):
|
||||
shape_env = ShapeEnv()
|
||||
a0 = shape_env.create_symint("a0", 2)
|
||||
|
||||
sym_int_encountered = False
|
||||
|
||||
class TestSymInt(TorchDispatchMode):
|
||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
assert func == torch.ops.aten.add.Tensor
|
||||
|
||||
nonlocal sym_int_encountered
|
||||
sym_int_encountered = kwargs["alpha"] is a0
|
||||
kwargs["alpha"] = 0
|
||||
return func(*args)
|
||||
|
||||
x = torch.rand([4, 4])
|
||||
with TestSymInt():
|
||||
y = torch.add(x, x, alpha=a0)
|
||||
|
||||
self.assertTrue(sym_int_encountered)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
@ -741,13 +741,15 @@ auto FunctionParameter::check(
|
||||
return true;
|
||||
}
|
||||
if (allow_numbers_as_tensors) {
|
||||
return THPUtils_checkScalar(obj) ||
|
||||
torch::is_symint_node(py::handle(obj)) ||
|
||||
torch::is_symfloat_node(py::handle(obj));
|
||||
return THPUtils_checkScalar(obj);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
case ParameterType::SCALAR:
|
||||
if (THPUtils_checkScalar(obj)) {
|
||||
return true;
|
||||
}
|
||||
// fallthrough
|
||||
case ParameterType::COMPLEX:
|
||||
if (PyComplex_Check(obj)) {
|
||||
return true;
|
||||
@ -1498,6 +1500,9 @@ at::Tensor PythonArgs::tensor_slow(int i) {
|
||||
scalar = at::Scalar(THPUtils_unpackComplexDouble(obj));
|
||||
} else if (THPUtils_checkDouble(obj)) {
|
||||
scalar = at::Scalar(THPUtils_unpackDouble(obj));
|
||||
// NB: we DO NOT put symbolic ints/floats into the Scalar itself,
|
||||
// because although Scalar supports SymInt/SymFloat, the subsequent
|
||||
// conversion to Tensor does not. Instead, do it out of band.
|
||||
} else if (torch::is_symint_node(py::handle(obj))) {
|
||||
save_symint = true;
|
||||
// This scalar value doesn't matter, it shouldn't ever actually
|
||||
@ -1560,6 +1565,15 @@ at::Scalar PythonArgs::scalar_slow(PyObject* arg) {
|
||||
if (PyComplex_Check(arg)) {
|
||||
return at::Scalar(THPUtils_unpackComplexDouble(arg));
|
||||
}
|
||||
|
||||
if (torch::is_symint_node(arg)) {
|
||||
return at::Scalar(py::cast<c10::SymInt>(arg));
|
||||
}
|
||||
|
||||
if (torch::is_symfloat_node(arg)) {
|
||||
return at::Scalar(py::cast<c10::SymFloat>(arg));
|
||||
}
|
||||
|
||||
return at::Scalar(THPUtils_unpackDouble(arg));
|
||||
}
|
||||
|
||||
|
@ -79,6 +79,7 @@
|
||||
#include <vector>
|
||||
|
||||
namespace torch {
|
||||
|
||||
inline bool is_symint_node(py::handle obj) {
|
||||
auto static tp_symn = py::type::of<c10::SymIntNodeImpl>();
|
||||
if (py::isinstance(obj, tp_symn)) {
|
||||
@ -98,6 +99,7 @@ inline bool is_symfloat_node(py::handle obj) {
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace torch
|
||||
|
||||
namespace pybind11 {
|
||||
@ -158,6 +160,17 @@ struct type_caster<c10::SymFloat> {
|
||||
} // namespace detail
|
||||
} // namespace pybind11
|
||||
|
||||
inline bool THPUtils_checkScalar(PyObject* obj) {
|
||||
#ifdef USE_NUMPY
|
||||
if (torch::utils::is_numpy_scalar(obj)) {
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
return PyFloat_Check(obj) || PyLong_Check(obj) || PyComplex_Check(obj) ||
|
||||
torch::is_symint_node(py::handle(obj)) ||
|
||||
torch::is_symfloat_node(py::handle(obj));
|
||||
}
|
||||
|
||||
namespace torch {
|
||||
|
||||
bool should_allow_numbers_as_tensors(const std::string& name);
|
||||
|
@ -139,15 +139,6 @@ inline bool THPUtils_checkDouble(PyObject* obj) {
|
||||
return PyFloat_Check(obj) || PyLong_Check(obj);
|
||||
}
|
||||
|
||||
inline bool THPUtils_checkScalar(PyObject* obj) {
|
||||
#ifdef USE_NUMPY
|
||||
if (torch::utils::is_numpy_scalar(obj)) {
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
return PyFloat_Check(obj) || PyLong_Check(obj) || PyComplex_Check(obj);
|
||||
}
|
||||
|
||||
inline double THPUtils_unpackDouble(PyObject* obj) {
|
||||
if (PyFloat_Check(obj)) {
|
||||
return PyFloat_AS_DOUBLE(obj);
|
||||
|
Reference in New Issue
Block a user