// Mirror of https://github.com/pytorch/pytorch.git
// (synced 2025-10-20 12:54:11 +08:00).
// This PR adds comments naming unused parameters in C++ declarations to
// improve code readability.
// Pull Request resolved: https://github.com/pytorch/pytorch/pull/164912
// Approved by: https://github.com/Skylion007
// 467 lines, 14 KiB, C++
#pragma once
|
|
|
|
#include <climits>
#include <cstdint>
#include <stdexcept>
#include <type_traits>
#include <utility>

#include <c10/core/OptionalRef.h>
#include <c10/core/ScalarType.h>
#include <c10/core/SymBool.h>
#include <c10/core/SymFloat.h>
#include <c10/core/SymInt.h>
#include <c10/core/SymNodeImpl.h>
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/Half.h>
#include <c10/util/TypeCast.h>
#include <c10/util/complex.h>
#include <c10/util/intrusive_ptr.h>
#include <c10/util/overflows.h>
|
|
|
|
namespace c10 {
|
|
|
|
/**
|
|
* Scalar represents a 0-dimensional tensor which contains a single element.
|
|
* Unlike a tensor, numeric literals (in C++) are implicitly convertible to
|
|
* Scalar (which is why, for example, we provide both add(Tensor) and
|
|
* add(Scalar) overloads for many operations). It may also be used in
|
|
* circumstances where you statically know a tensor is 0-dim and single size,
|
|
* but don't know its type.
|
|
*/
|
|
class C10_API Scalar {
|
|
public:
|
|
Scalar() : Scalar(int64_t(0)) {}
|
|
|
|
void destroy() {
|
|
if (Tag::HAS_si == tag || Tag::HAS_sd == tag || Tag::HAS_sb == tag) {
|
|
raw::intrusive_ptr::decref(v.p);
|
|
v.p = nullptr;
|
|
}
|
|
}
|
|
|
|
~Scalar() {
|
|
destroy();
|
|
}
|
|
|
|
#define DEFINE_IMPLICIT_CTOR(type, name) \
|
|
Scalar(type vv) : Scalar(vv, true) {}
|
|
|
|
AT_FORALL_SCALAR_TYPES_AND3(Half, BFloat16, ComplexHalf, DEFINE_IMPLICIT_CTOR)
|
|
AT_FORALL_COMPLEX_TYPES(DEFINE_IMPLICIT_CTOR)
|
|
AT_FORALL_FLOAT8_TYPES(DEFINE_IMPLICIT_CTOR)
|
|
|
|
// Helper constructors to allow Scalar creation from long and long long types
|
|
// As std::is_same_v<long, long long> is false(except Android), one needs to
|
|
// provide a constructor from either long or long long in addition to one from
|
|
// int64_t
|
|
#if defined(__APPLE__) || defined(__MACOSX)
|
|
static_assert(
|
|
std::is_same_v<long long, int64_t>,
|
|
"int64_t is the same as long long on MacOS");
|
|
Scalar(long vv) : Scalar(vv, true) {}
|
|
#endif
|
|
#if defined(_MSC_VER)
|
|
static_assert(
|
|
std::is_same_v<long long, int64_t>,
|
|
"int64_t is the same as long long on Windows");
|
|
Scalar(long vv) : Scalar(vv, true) {}
|
|
#endif
|
|
#if defined(__linux__) && !defined(__ANDROID__)
|
|
static_assert(
|
|
sizeof(void*) != 8 || std::is_same_v<long, int64_t>,
|
|
"int64_t is the same as long on 64 bit Linux");
|
|
#if LONG_MAX != INT_MAX
|
|
Scalar(long long vv) : Scalar(vv, true) {}
|
|
#endif /* not 32-bit system */
|
|
#endif
|
|
|
|
Scalar(uint16_t vv) : Scalar(vv, true) {}
|
|
Scalar(uint32_t vv) : Scalar(vv, true) {}
|
|
Scalar(uint64_t vv) {
|
|
if (vv > static_cast<uint64_t>(INT64_MAX)) {
|
|
tag = Tag::HAS_u;
|
|
v.u = vv;
|
|
} else {
|
|
tag = Tag::HAS_i;
|
|
// NB: no need to use convert, we've already tested convertibility
|
|
v.i = static_cast<int64_t>(vv);
|
|
}
|
|
}
|
|
|
|
#undef DEFINE_IMPLICIT_CTOR
|
|
|
|
// Value* is both implicitly convertible to SymbolicVariable and bool which
|
|
// causes ambiguity error. Specialized constructor for bool resolves this
|
|
// problem.
|
|
template <
|
|
typename T,
|
|
typename std::enable_if_t<std::is_same_v<T, bool>, bool>* = nullptr>
|
|
Scalar(T vv) : tag(Tag::HAS_b) {
|
|
v.i = convert<int64_t, bool>(vv);
|
|
}
|
|
|
|
template <
|
|
typename T,
|
|
typename std::enable_if_t<std::is_same_v<T, c10::SymBool>, bool>* =
|
|
nullptr>
|
|
Scalar(T vv) : tag(Tag::HAS_sb) {
|
|
v.i = convert<int64_t, c10::SymBool>(vv);
|
|
}
|
|
|
|
#define DEFINE_ACCESSOR(type, name) \
|
|
type to##name() const { \
|
|
if (Tag::HAS_d == tag) { \
|
|
return checked_convert<type, double>(v.d, #type); \
|
|
} else if (Tag::HAS_z == tag) { \
|
|
return checked_convert<type, c10::complex<double>>(v.z, #type); \
|
|
} else if (Tag::HAS_sd == tag) { \
|
|
return checked_convert<type, double>( \
|
|
toSymFloat().guard_float(__FILE__, __LINE__), #type); \
|
|
} \
|
|
if (Tag::HAS_b == tag) { \
|
|
return checked_convert<type, bool>(v.i, #type); \
|
|
} else if (Tag::HAS_i == tag) { \
|
|
return checked_convert<type, int64_t>(v.i, #type); \
|
|
} else if (Tag::HAS_u == tag) { \
|
|
return checked_convert<type, uint64_t>(v.u, #type); \
|
|
} else if (Tag::HAS_si == tag) { \
|
|
return checked_convert<type, int64_t>( \
|
|
toSymInt().guard_int(__FILE__, __LINE__), #type); \
|
|
} else if (Tag::HAS_sb == tag) { \
|
|
return checked_convert<type, int64_t>( \
|
|
toSymBool().guard_bool(__FILE__, __LINE__), #type); \
|
|
} \
|
|
TORCH_CHECK(false) \
|
|
}
|
|
|
|
// TODO: Support ComplexHalf accessor
|
|
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_ACCESSOR)
|
|
DEFINE_ACCESSOR(uint16_t, UInt16)
|
|
DEFINE_ACCESSOR(uint32_t, UInt32)
|
|
DEFINE_ACCESSOR(uint64_t, UInt64)
|
|
|
|
#undef DEFINE_ACCESSOR
|
|
|
|
SymInt toSymInt() const {
|
|
if (Tag::HAS_si == tag) {
|
|
return c10::SymInt(intrusive_ptr<SymNodeImpl>::reclaim_copy(
|
|
static_cast<SymNodeImpl*>(v.p)));
|
|
} else {
|
|
return toLong();
|
|
}
|
|
}
|
|
|
|
SymFloat toSymFloat() const {
|
|
if (Tag::HAS_sd == tag) {
|
|
return c10::SymFloat(intrusive_ptr<SymNodeImpl>::reclaim_copy(
|
|
static_cast<SymNodeImpl*>(v.p)));
|
|
} else {
|
|
return toDouble();
|
|
}
|
|
}
|
|
|
|
SymBool toSymBool() const {
|
|
if (Tag::HAS_sb == tag) {
|
|
return c10::SymBool(intrusive_ptr<SymNodeImpl>::reclaim_copy(
|
|
static_cast<SymNodeImpl*>(v.p)));
|
|
} else {
|
|
return toBool();
|
|
}
|
|
}
|
|
|
|
// also support scalar.to<int64_t>();
|
|
// Deleted for unsupported types, but specialized below for supported types
|
|
template <typename T>
|
|
T to() const = delete;
|
|
|
|
// audit uses of data_ptr
|
|
const void* data_ptr() const {
|
|
TORCH_INTERNAL_ASSERT(!isSymbolic());
|
|
return static_cast<const void*>(&v);
|
|
}
|
|
|
|
bool isFloatingPoint() const {
|
|
return Tag::HAS_d == tag || Tag::HAS_sd == tag;
|
|
}
|
|
|
|
[[deprecated(
|
|
"isIntegral is deprecated. Please use the overload with 'includeBool' parameter instead.")]] bool
|
|
isIntegral() const {
|
|
return Tag::HAS_i == tag || Tag::HAS_si == tag || Tag::HAS_u == tag;
|
|
}
|
|
|
|
bool isIntegral(bool includeBool) const {
|
|
return Tag::HAS_i == tag || Tag::HAS_si == tag || Tag::HAS_u == tag ||
|
|
(includeBool && isBoolean());
|
|
}
|
|
|
|
// See Note [Meaning of HAS_u]
|
|
bool isUnsigned() const {
|
|
return Tag::HAS_u == tag || (Tag::HAS_i == tag && v.i >= 0);
|
|
}
|
|
|
|
bool isComplex() const {
|
|
return Tag::HAS_z == tag;
|
|
}
|
|
bool isBoolean() const {
|
|
return Tag::HAS_b == tag || Tag::HAS_sb == tag;
|
|
}
|
|
|
|
// you probably don't actually want these; they're mostly for testing
|
|
bool isSymInt() const {
|
|
return Tag::HAS_si == tag;
|
|
}
|
|
bool isSymFloat() const {
|
|
return Tag::HAS_sd == tag;
|
|
}
|
|
bool isSymBool() const {
|
|
return Tag::HAS_sb == tag;
|
|
}
|
|
|
|
bool isSymbolic() const {
|
|
return Tag::HAS_si == tag || Tag::HAS_sd == tag || Tag::HAS_sb == tag;
|
|
}
|
|
|
|
C10_ALWAYS_INLINE Scalar& operator=(Scalar&& other) noexcept {
|
|
if (&other == this) {
|
|
return *this;
|
|
}
|
|
|
|
destroy();
|
|
moveFrom(std::move(other));
|
|
return *this;
|
|
}
|
|
|
|
C10_ALWAYS_INLINE Scalar& operator=(const Scalar& other) {
|
|
if (&other == this) {
|
|
return *this;
|
|
}
|
|
|
|
*this = Scalar(other);
|
|
return *this;
|
|
}
|
|
|
|
Scalar operator-() const;
|
|
Scalar conj() const;
|
|
Scalar log() const;
|
|
|
|
template <
|
|
typename T,
|
|
typename std::enable_if_t<!c10::is_complex<T>::value, int> = 0>
|
|
bool equal(T num) const {
|
|
if (isComplex()) {
|
|
TORCH_INTERNAL_ASSERT(!isSymbolic());
|
|
auto val = v.z;
|
|
return (val.real() == num) && (val.imag() == T());
|
|
} else if (isFloatingPoint()) {
|
|
return toDouble() == num;
|
|
} else if (tag == Tag::HAS_i) {
|
|
if (overflows<T>(v.i, /* strict_unsigned */ true)) {
|
|
return false;
|
|
} else {
|
|
return static_cast<T>(v.i) == num;
|
|
}
|
|
} else if (tag == Tag::HAS_u) {
|
|
if (overflows<T>(v.u, /* strict_unsigned */ true)) {
|
|
return false;
|
|
} else {
|
|
return static_cast<T>(v.u) == num;
|
|
}
|
|
} else if (tag == Tag::HAS_si) {
|
|
TORCH_INTERNAL_ASSERT(false, "NYI SymInt equality");
|
|
} else if (isBoolean()) {
|
|
// boolean scalar does not equal to a non boolean value
|
|
TORCH_INTERNAL_ASSERT(!isSymbolic());
|
|
return false;
|
|
} else {
|
|
TORCH_INTERNAL_ASSERT(false);
|
|
}
|
|
}
|
|
|
|
template <
|
|
typename T,
|
|
typename std::enable_if_t<c10::is_complex<T>::value, int> = 0>
|
|
bool equal(T num) const {
|
|
if (isComplex()) {
|
|
TORCH_INTERNAL_ASSERT(!isSymbolic());
|
|
return v.z == num;
|
|
} else if (isFloatingPoint()) {
|
|
return (toDouble() == num.real()) && (num.imag() == T());
|
|
} else if (tag == Tag::HAS_i) {
|
|
if (overflows<T>(v.i, /* strict_unsigned */ true)) {
|
|
return false;
|
|
} else {
|
|
return static_cast<T>(v.i) == num.real() && num.imag() == T();
|
|
}
|
|
} else if (tag == Tag::HAS_u) {
|
|
if (overflows<T>(v.u, /* strict_unsigned */ true)) {
|
|
return false;
|
|
} else {
|
|
return static_cast<T>(v.u) == num.real() && num.imag() == T();
|
|
}
|
|
} else if (tag == Tag::HAS_si) {
|
|
TORCH_INTERNAL_ASSERT(false, "NYI SymInt equality");
|
|
} else if (isBoolean()) {
|
|
// boolean scalar does not equal to a non boolean value
|
|
TORCH_INTERNAL_ASSERT(!isSymbolic());
|
|
return false;
|
|
} else {
|
|
TORCH_INTERNAL_ASSERT(false);
|
|
}
|
|
}
|
|
|
|
bool equal(bool num) const {
|
|
if (isBoolean()) {
|
|
TORCH_INTERNAL_ASSERT(!isSymbolic());
|
|
return static_cast<bool>(v.i) == num;
|
|
} else {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
ScalarType type() const {
|
|
if (isComplex()) {
|
|
return ScalarType::ComplexDouble;
|
|
} else if (isFloatingPoint()) {
|
|
return ScalarType::Double;
|
|
} else if (isIntegral(/*includeBool=*/false)) {
|
|
// Represent all integers as long, UNLESS it is unsigned and therefore
|
|
// unrepresentable as long
|
|
if (Tag::HAS_u == tag) {
|
|
return ScalarType::UInt64;
|
|
}
|
|
return ScalarType::Long;
|
|
} else if (isBoolean()) {
|
|
return ScalarType::Bool;
|
|
} else {
|
|
throw std::runtime_error("Unknown scalar type.");
|
|
}
|
|
}
|
|
|
|
Scalar(Scalar&& rhs) noexcept : tag(rhs.tag) {
|
|
moveFrom(std::move(rhs));
|
|
}
|
|
|
|
Scalar(const Scalar& rhs) : tag(rhs.tag), v(rhs.v) {
|
|
if (isSymbolic()) {
|
|
c10::raw::intrusive_ptr::incref(v.p);
|
|
}
|
|
}
|
|
|
|
Scalar(c10::SymInt si) {
|
|
if (auto m = si.maybe_as_int()) {
|
|
tag = Tag::HAS_i;
|
|
v.i = *m;
|
|
} else {
|
|
tag = Tag::HAS_si;
|
|
v.p = std::move(si).release();
|
|
}
|
|
}
|
|
|
|
Scalar(c10::SymFloat sd) {
|
|
if (sd.is_symbolic()) {
|
|
tag = Tag::HAS_sd;
|
|
v.p = std::move(sd).release();
|
|
} else {
|
|
tag = Tag::HAS_d;
|
|
v.d = sd.as_float_unchecked();
|
|
}
|
|
}
|
|
|
|
Scalar(c10::SymBool sb) {
|
|
if (auto m = sb.maybe_as_bool()) {
|
|
tag = Tag::HAS_b;
|
|
v.i = *m;
|
|
} else {
|
|
tag = Tag::HAS_sb;
|
|
v.p = std::move(sb).release();
|
|
}
|
|
}
|
|
|
|
// We can't set v in the initializer list using the
|
|
// syntax v{ .member = ... } because it doesn't work on MSVC
|
|
private:
|
|
enum class Tag { HAS_d, HAS_i, HAS_u, HAS_z, HAS_b, HAS_sd, HAS_si, HAS_sb };
|
|
|
|
// Note [Meaning of HAS_u]
|
|
// ~~~~~~~~~~~~~~~~~~~~~~~
|
|
// HAS_u is a bit special. On its face, it just means that we
|
|
// are holding an unsigned integer. However, we generally don't
|
|
// distinguish between different bit sizes in Scalar (e.g., we represent
|
|
// float as double), instead, it represents a mathematical notion
|
|
// of some quantity (integral versus floating point). So actually,
|
|
// HAS_u is used solely to represent unsigned integers that could
|
|
// not be represented as a signed integer. That means only uint64_t
|
|
// potentially can get this tag; smaller types like uint8_t fits into a
|
|
// regular int and so for BC reasons we keep as an int.
|
|
|
|
// NB: assumes that self has already been cleared
|
|
// NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
|
|
C10_ALWAYS_INLINE void moveFrom(Scalar&& rhs) noexcept {
|
|
v = rhs.v;
|
|
tag = rhs.tag;
|
|
if (rhs.tag == Tag::HAS_si || rhs.tag == Tag::HAS_sd ||
|
|
rhs.tag == Tag::HAS_sb) {
|
|
// Move out of scalar
|
|
rhs.tag = Tag::HAS_i;
|
|
rhs.v.i = 0;
|
|
}
|
|
}
|
|
|
|
Tag tag;
|
|
|
|
union v_t {
|
|
double d{};
|
|
int64_t i;
|
|
// See Note [Meaning of HAS_u]
|
|
uint64_t u;
|
|
c10::complex<double> z;
|
|
c10::intrusive_ptr_target* p;
|
|
// NOLINTNEXTLINE(modernize-use-equals-default)
|
|
v_t() {} // default constructor
|
|
} v;
|
|
|
|
template <
|
|
typename T,
|
|
typename std::enable_if_t<
|
|
std::is_integral_v<T> && !std::is_same_v<T, bool>,
|
|
bool>* = nullptr>
|
|
Scalar(T vv, bool /*unused*/) : tag(Tag::HAS_i) {
|
|
v.i = convert<decltype(v.i), T>(vv);
|
|
}
|
|
|
|
template <
|
|
typename T,
|
|
typename std::enable_if_t<
|
|
!std::is_integral_v<T> && !c10::is_complex<T>::value,
|
|
bool>* = nullptr>
|
|
Scalar(T vv, bool /*unused*/) : tag(Tag::HAS_d) {
|
|
v.d = convert<decltype(v.d), T>(vv);
|
|
}
|
|
|
|
template <
|
|
typename T,
|
|
typename std::enable_if_t<c10::is_complex<T>::value, bool>* = nullptr>
|
|
Scalar(T vv, bool /*unused*/) : tag(Tag::HAS_z) {
|
|
v.z = convert<decltype(v.z), T>(vv);
|
|
}
|
|
};
|
|
|
|
using OptionalScalarRef = c10::OptionalRef<Scalar>;
|
|
|
|
// define the scalar.to<int64_t>() specializations
|
|
#define DEFINE_TO(T, name) \
|
|
template <> \
|
|
inline T Scalar::to<T>() const { \
|
|
return to##name(); \
|
|
}
|
|
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_TO)
|
|
DEFINE_TO(uint16_t, UInt16)
|
|
DEFINE_TO(uint32_t, UInt32)
|
|
DEFINE_TO(uint64_t, UInt64)
|
|
#undef DEFINE_TO
|
|
|
|
} // namespace c10
|