mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This is a follow-up to #164912 that marks unused C++ parameters to improve code readability. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165121 Approved by: https://github.com/Skylion007
884 lines
42 KiB
C++
884 lines
42 KiB
C++
#pragma once
|
|
|
|
#include <ATen/Context.h>
|
|
#include <c10/core/DeviceType.h>
|
|
#include <torch/csrc/autograd/autograd.h>
|
|
#include <torch/csrc/autograd/edge.h>
|
|
#include <torch/csrc/autograd/function.h>
|
|
#include <torch/csrc/autograd/generated/variable_factories.h>
|
|
#include <torch/csrc/autograd/variable.h>
|
|
#include <torch/csrc/jit/api/compilation_unit.h>
|
|
#include <torch/csrc/jit/api/module.h>
|
|
#include <torch/csrc/jit/frontend/error_report.h>
|
|
#include <torch/csrc/jit/ir/ir.h>
|
|
#include <torch/csrc/jit/mobile/register_ops_common_utils.h>
|
|
#include <torch/csrc/jit/runtime/custom_operator.h>
|
|
#include <torch/csrc/jit/runtime/graph_executor.h>
|
|
#include <torch/csrc/jit/runtime/jit_exception.h>
|
|
#include <torch/csrc/jit/runtime/logging.h>
|
|
#include <torch/csrc/jit/runtime/operator.h>
|
|
#include <torch/csrc/jit/runtime/print_handler.h>
|
|
#include <torch/csrc/jit/runtime/profiling_record.h>
|
|
#include <torch/csrc/jit/runtime/vararg_functions.h>
|
|
#include <torch/csrc/jit/serialization/pickle.h>
|
|
|
|
#include <ATen/ExpandUtils.h>
|
|
#include <ATen/Parallel.h>
|
|
#include <ATen/WrapDimUtils.h>
|
|
#include <ATen/core/Dict.h>
|
|
#include <ATen/core/Generator.h>
|
|
#include <ATen/core/ivalue.h>
|
|
#include <c10/core/Device.h>
|
|
#include <c10/core/thread_pool.h>
|
|
#include <c10/util/SmallVector.h>
|
|
#include <c10/util/irange.h>
|
|
|
|
namespace torch::jit {
|
|
// Shorthand for c10::AliasAnalysisKind::FROM_SCHEMA. The operator-definition
// macros later in this header pass its result as the alias-analysis argument.
inline constexpr c10::AliasAnalysisKind aliasAnalysisFromSchema() {
  constexpr auto kind = c10::AliasAnalysisKind::FROM_SCHEMA;
  return kind;
}
|
|
|
|
// Shorthand for c10::AliasAnalysisKind::CONSERVATIVE.
inline constexpr c10::AliasAnalysisKind aliasAnalysisConservative() {
  constexpr auto kind = c10::AliasAnalysisKind::CONSERVATIVE;
  return kind;
}
|
|
|
|
// Shorthand for c10::AliasAnalysisKind::INTERNAL_SPECIAL_CASE.
inline constexpr c10::AliasAnalysisKind aliasAnalysisSpecialCase() {
  constexpr auto kind = c10::AliasAnalysisKind::INTERNAL_SPECIAL_CASE;
  return kind;
}
|
|
|
|
template <class T>
|
|
c10::List<T> make_result_list(const TypePtr& elemType) {
|
|
return c10::List<T>();
|
|
}
|
|
|
|
template <>
|
|
c10::impl::GenericList make_result_list<IValue>(const TypePtr& elemType);
|
|
|
|
// Rounds `a` to the nearest integer, resolving exact ties toward the even
// neighbour, as Python's builtin round() does
// (https://docs.python.org/3/library/functions.html#round).
//
// Only the halfway case needs special handling; it is computed as
// round(a / 2) * 2. For positive `a`, write a = x_e + x_r, where x_e is an
// even integer and x_r is either 0.5 or 1.5. Then
//   round(a / 2) * 2 = x_e + round(x_r / 2) * 2,
// and round(x_r / 2) * 2 is 0 or 2, so the result is always even. By symmetry
// the same holds for negative values. Every non-tie input goes straight to
// std::round.
inline double round_to_even(double a) {
  const double fraction = a - std::floor(a);
  if (fraction == 0.5) {
    return std::round(a * 0.5) * 2.0;
  }
  return std::round(a);
}
|
|
|
|
// using the rules from python_arg_parser FunctionParameter::check
|
|
// tensor cannot have grad set, tensor must be 0 dim,
|
|
// and if the dest is an int the source must be integral type
|
|
// Declaration only; the checks listed above are implemented out-of-line.
// NOTE(review): `toInt` presumably selects the "destination is an int" rule
// (source must be an integral type) — confirm against the definition.
void checkImplicitTensorToNum(const at::Tensor& t, bool toInt);
|
|
|
|
// Integer division with Python semantics: the quotient is rounded toward
// negative infinity, whereas C++ `/` truncates toward zero.
//
// Throws std::runtime_error("division by 0") when `b` is zero.
// Note: INT64_MIN / -1 is not guarded and overflows, as in plain C++ division.
[[maybe_unused]] static int64_t floordiv(int64_t a, int64_t b) {
  if (b == 0) {
    throw std::runtime_error("division by 0");
  }

  const int64_t truncated = a / b;
  const bool same_sign = (a > 0) == (b > 0);
  if (same_sign) {
    // Truncation and flooring agree when the operands share a sign.
    return truncated;
  }

  // Signs differ (or `a` is zero): step one further down only when the
  // division was inexact.
  const bool inexact = (a % b) != 0;
  return inexact ? truncated - 1 : truncated;
}
|
|
// Called by floor() and ceil() below before they convert a double to int64_t.
// Defined out-of-line. NOTE(review): presumably raises when `a` cannot be
// represented as an int64_t — confirm against the definition.
TORCH_API void checkDoubleInRange(double a);
|
|
// Runs checkDoubleInRange(a), then returns std::floor(a) converted to int64_t.
[[maybe_unused]] static int64_t floor(double a) {
  checkDoubleInRange(a);
  const double rounded_down = std::floor(a);
  return static_cast<int64_t>(rounded_down);
}
|
|
// Runs checkDoubleInRange(a), then returns std::ceil(a) converted to int64_t.
[[maybe_unused]] static int64_t ceil(double a) {
  checkDoubleInRange(a);
  const double rounded_up = std::ceil(a);
  return static_cast<int64_t>(rounded_up);
}
|
|
|
|
// Greatest common divisor via the Euclidean algorithm. The result is made
// non-negative to match Python's gcd; gcd(0, 0) is 0.
[[maybe_unused]] static int64_t gcd(int64_t a, int64_t b) {
  while (b != 0) {
    const int64_t remainder = a % b;
    a = b;
    b = remainder;
  }
  // `a` now holds the gcd up to sign; report its magnitude.
  return a < 0 ? -a : a;
}
|
|
|
|
// Declaration only; defined out-of-line.
// NOTE(review): looks like a building block of factorial() below — confirm
// the meaning of `n` and `m` against the definition.
int64_t partProduct(int n, int m);
|
|
|
|
// Declaration only; defined out-of-line. `p` and `r` are taken by non-const
// reference, so the callee can update them.
// NOTE(review): looks like a building block of factorial() below — confirm.
void loop(int n, int64_t& p, int64_t& r);
|
|
|
|
// Declaration only; defined out-of-line.
// NOTE(review): by its name, presumably returns v minus the number of set bits
// in v — confirm against the definition.
int nminussumofbits(int v);
|
|
|
|
// Declaration only; defined out-of-line.
// NOTE(review): presumably computes n!; the behaviour for negative `n` or on
// overflow is not visible from this header — confirm.
int64_t factorial(int n);
|
|
// pi / 180, with pi obtained as acos(-1): the factor that converts degrees to
// radians. `static` in a header gives every translation unit its own copy.
static const double degToRad = std::acos(-1.0) / 180.0;
|
|
// 180 / pi, with pi obtained as acos(-1): the factor that converts radians to
// degrees. `static` in a header gives every translation unit its own copy.
static const double radToDeg = 180.0 / std::acos(-1.0);
|
|
// Declaration only; defined out-of-line.
// NOTE(review): presumably converts `x` from radians to degrees — confirm.
double degrees(double x);
|
|
// Declaration only; defined out-of-line.
// NOTE(review): presumably converts `x` from degrees to radians — confirm.
double radians(double x);
|
|
|
|
// Convert a Python index (which may be negative) into an index usable for a
// C++ container
|
|
|
|
// Equivalent to list.at(idx)
|
|
template <typename T>
|
|
decltype(auto) getItem(const c10::List<T>& list, int64_t idx) {
|
|
const int64_t list_size = list.size();
|
|
const int64_t normalized_idx = normalizeIndex(idx, list_size);
|
|
if (normalized_idx < 0 || normalized_idx >= list_size) {
|
|
throw std::out_of_range("list index out of range");
|
|
}
|
|
return list.get(normalized_idx);
|
|
}
|
|
|
|
template <typename T>
|
|
void setItem(const c10::List<T>& list, int64_t idx, T&& value) {
|
|
const int64_t list_size = list.size();
|
|
const int64_t normalized_idx = normalizeIndex(idx, list_size);
|
|
if (normalized_idx < 0 || normalized_idx >= list_size) {
|
|
throw std::out_of_range("list index out of range");
|
|
}
|
|
list.set(normalized_idx, std::forward<T>(value));
|
|
}
|
|
|
|
// Declaration only; defined out-of-line.
// NOTE(review): presumably the stack-based implementation of list.append() —
// operand order on the stack is not visible here; confirm.
void listAppend(Stack& stack);
|
|
|
|
// Declaration only; defined out-of-line.
// NOTE(review): presumably the stack-based implementation of list.reverse() —
// confirm.
void listReverse(Stack& stack);
|
|
|
|
template <typename T>
|
|
void minList(Stack& stack) {
|
|
c10::List<T> a = pop(stack).to<c10::List<T>>();
|
|
c10::List<T> b = pop(stack).to<c10::List<T>>();
|
|
|
|
size_t min_size = std::min(a.size(), b.size());
|
|
for (const auto i : c10::irange(min_size)) {
|
|
if (a[i] == b[i]) {
|
|
continue;
|
|
}
|
|
|
|
push(stack, a[i] < b[i] ? a : b);
|
|
return;
|
|
}
|
|
|
|
push(stack, b.size() < a.size() ? b : a);
|
|
}
|
|
|
|
template <typename T>
|
|
void maxList(Stack& stack) {
|
|
c10::List<T> a = pop(stack).to<c10::List<T>>();
|
|
c10::List<T> b = pop(stack).to<c10::List<T>>();
|
|
|
|
size_t min_size = std::min(a.size(), b.size());
|
|
for (const auto i : c10::irange(min_size)) {
|
|
if (a[i] == b[i]) {
|
|
continue;
|
|
}
|
|
|
|
push(stack, a[i] > b[i] ? a : b);
|
|
return;
|
|
}
|
|
|
|
push(stack, b.size() > a.size() ? b : a);
|
|
}
|
|
|
|
// Declaration only; defined out-of-line.
// NOTE(review): `empty_message` is presumably the error text used when the
// list to pop from is empty — confirm against the definition.
void listPopImpl(Stack& stack, const char* empty_message);
|
|
|
|
// Declaration only; defined out-of-line.
// NOTE(review): presumably the stack-based implementation of list.pop() —
// confirm.
void listPop(Stack& stack);
|
|
|
|
// Declaration only; defined out-of-line.
// NOTE(review): presumably the stack-based implementation of list.clear() —
// confirm.
void listClear(Stack& stack);
|
|
|
|
// Declaration only; defined out-of-line.
// NOTE(review): presumably implements `del list[idx]` — confirm.
void listDelete(Stack& stack);
|
|
|
|
// Declaration only; defined out-of-line.
// NOTE(review): presumably the stack-based implementation of list.insert() —
// confirm.
void listInsert(Stack& stack);
|
|
|
|
template <typename T>
|
|
void listRemove(Stack& stack) {
|
|
T elem = pop(stack).to<T>();
|
|
c10::List<T> list = pop(stack).to<c10::List<T>>();
|
|
|
|
auto pos = std::find(list.begin(), list.end(), elem);
|
|
|
|
if (pos != list.end()) {
|
|
list.erase(pos);
|
|
} else {
|
|
TORCH_CHECK(false, "list.remove(x): x not in list");
|
|
}
|
|
}
|
|
|
|
template <typename T>
|
|
void listMin(Stack& stack) {
|
|
c10::List<T> list = pop(stack).to<c10::List<T>>();
|
|
size_t list_size = list.size();
|
|
if (list_size == 0) {
|
|
throw std::runtime_error("min() arg is an empty sequence");
|
|
}
|
|
|
|
T min_elem = list[0];
|
|
for (const auto i : c10::irange(1, list_size)) {
|
|
T elem = list[i];
|
|
min_elem = elem < min_elem ? elem : min_elem;
|
|
}
|
|
|
|
stack.push_back(min_elem);
|
|
}
|
|
|
|
template <typename T>
|
|
void listMax(Stack& stack) {
|
|
c10::List<T> list = pop(stack).to<c10::List<T>>();
|
|
size_t list_size = list.size();
|
|
if (list_size == 0) {
|
|
throw std::runtime_error("max() arg is an empty sequence");
|
|
}
|
|
|
|
T max_elem = list[0];
|
|
for (const auto i : c10::irange(1, list_size)) {
|
|
T elem = list[i];
|
|
max_elem = elem > max_elem ? elem : max_elem;
|
|
}
|
|
|
|
stack.push_back(max_elem);
|
|
}
|
|
|
|
template <>
|
|
void listRemove<at::Tensor>(Stack& stack);
|
|
|
|
template <typename T>
|
|
void listIndex(Stack& stack) {
|
|
T elem = pop(stack).to<T>();
|
|
c10::List<T> list = pop(stack).to<c10::List<T>>();
|
|
|
|
auto pos = std::find(list.begin(), list.end(), elem);
|
|
|
|
if (pos != list.end()) {
|
|
push(stack, static_cast<int64_t>(std::distance(list.begin(), pos)));
|
|
} else {
|
|
TORCH_CHECK(false, "'", elem, "' is not in list");
|
|
}
|
|
}
|
|
|
|
template <>
|
|
void listIndex<at::Tensor>(Stack& stack);
|
|
|
|
template <typename T>
|
|
void listCount(Stack& stack) {
|
|
T elem = pop(stack).to<T>();
|
|
c10::List<T> list = pop(stack).to<c10::List<T>>();
|
|
|
|
const int64_t count = std::count(list.begin(), list.end(), elem);
|
|
push(stack, count);
|
|
}
|
|
|
|
template <>
|
|
void listCount<at::Tensor>(Stack& stack);
|
|
|
|
// Declaration only; defined out-of-line.
// NOTE(review): presumably the stack-based implementation of list.extend() —
// confirm.
void listExtend(Stack& stack);
|
|
|
|
// Declaration only; defined out-of-line.
// NOTE(review): presumably the stack-based implementation of list.copy() —
// confirm.
void listCopy(Stack& stack);
|
|
|
|
// Declaration only; defined out-of-line.
// NOTE(review): presumably implements indexing `list[idx]` — confirm.
void listSelect(Stack& stack);
|
|
|
|
// Declaration only; defined out-of-line.
// NOTE(review): presumably implements len(list) — confirm.
void listLen(Stack& stack);
|
|
|
|
template <typename T>
|
|
void listEq(Stack& stack) {
|
|
c10::List<T> b = pop(stack).to<c10::List<T>>();
|
|
c10::List<T> a = pop(stack).to<c10::List<T>>();
|
|
push(stack, a == b);
|
|
}
|
|
|
|
template <typename T>
|
|
void listNe(Stack& stack) {
|
|
c10::List<T> b = pop(stack).to<c10::List<T>>();
|
|
c10::List<T> a = pop(stack).to<c10::List<T>>();
|
|
push(stack, a != b);
|
|
}
|
|
|
|
inline bool tensor_list_equal(
|
|
const c10::List<at::Tensor>& a,
|
|
const c10::List<at::Tensor>& b) {
|
|
if (a.size() != b.size()) {
|
|
return false;
|
|
}
|
|
|
|
for (const auto i : c10::irange(a.size())) {
|
|
const at::Tensor& a_element = a[i];
|
|
const at::Tensor& b_element = b[i];
|
|
// This preserves Python's semantics, which uses eq() to compare two
|
|
// elements, then passes the result to bool().
|
|
// see: https://docs.python.org/3.4/reference/datamodel.html#object.__ge__
|
|
const auto cmp_result = a_element.eq(b_element);
|
|
if (!at::native::is_nonzero(cmp_result)) {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
// Specialization for at::Tensor, since it doesn't define operator==
|
|
template <>
|
|
void listEq<at::Tensor>(Stack& stack);
|
|
|
|
// Specialization for at::Tensor, since it doesn't define operator==
|
|
template <>
|
|
void listNe<at::Tensor>(Stack& stack);
|
|
|
|
// Declaration only; defined out-of-line.
// NOTE(review): presumably implements list(x), i.e. building a new list from
// an existing one — confirm.
void listList(Stack& stack);
|
|
|
|
template <typename T>
|
|
void listContains(Stack& stack) {
|
|
auto key = pop(stack).to<T>();
|
|
auto list = pop(stack).to<c10::List<T>>();
|
|
// NOLINTNEXTLINE(performance-implicit-conversion-in-loop)
|
|
for (const T& item : list) {
|
|
if (item == key) {
|
|
push(stack, true);
|
|
return;
|
|
}
|
|
}
|
|
push(stack, false);
|
|
}
|
|
|
|
// Declaration only; defined out-of-line.
// NOTE(review): presumably implements `list + list` — confirm.
void listAdd(Stack& stack);
|
|
|
|
// Declaration only; defined out-of-line.
// NOTE(review): presumably implements `list += list` — confirm.
void listInplaceAdd(Stack& stack);
|
|
|
|
// Declaration only; defined out-of-line.
// NOTE(review): presumably implements `list *= int` — confirm.
void listMulIntLeftInPlace(Stack& stack);
|
|
|
|
// Declaration only; defined out-of-line.
// NOTE(review): presumably implements `list * int` — confirm operand order.
void listMulIntLeft(Stack& stack);
|
|
|
|
// Declaration only; defined out-of-line.
// NOTE(review): presumably implements `int * list` — confirm operand order.
void listMulIntRight(Stack& stack);
|
|
|
|
// Declaration only; defined out-of-line.
// NOTE(review): presumably implements slicing `list[start:end:step]` —
// confirm.
void listSlice(Stack& stack);
|
|
|
|
template <typename T>
|
|
void listSort(Stack& stack) {
|
|
bool reverse = pop(stack).toBool();
|
|
c10::List<T> list = pop(stack).to<c10::List<T>>();
|
|
std::sort(list.begin(), list.end(), [reverse](const T& a, const T& b) {
|
|
// FBCode errors without this check - "strict weak ordering"
|
|
// TODO: remove when possible, since it just slows down
|
|
// sorting and doesn't do anything useful
|
|
if (a == b) {
|
|
return false;
|
|
}
|
|
return (a < b) != reverse;
|
|
});
|
|
}
|
|
|
|
// Specialization for at::Tensor
|
|
template <>
|
|
void listSort<at::Tensor>(Stack& stack);
|
|
|
|
template <typename T>
|
|
void listCopyAndSort(Stack& stack) {
|
|
c10::List<T> list = pop(stack).to<c10::List<T>>();
|
|
auto list_copied = list.copy();
|
|
std::sort(list_copied.begin(), list_copied.end(), [](const T& a, const T& b) {
|
|
// "strict weak ordering" issue - see other sort
|
|
if (a == b) {
|
|
return false;
|
|
}
|
|
return a < b;
|
|
});
|
|
push(stack, list_copied);
|
|
}
|
|
|
|
// Specialization for at::Tensor
|
|
template <>
|
|
void listCopyAndSort<at::Tensor>(Stack& stack);
|
|
|
|
// Declaration only; defined out-of-line.
// NOTE(review): presumably implements `list[idx] = value` — confirm.
void listSetItem(Stack& stack);
|
|
|
|
struct OperatorGeneratorArgs {
|
|
const char* schema_str;
|
|
bool isOperationCreator;
|
|
union {
|
|
void (*operation)(Stack&);
|
|
OperationCreator operationCreator;
|
|
};
|
|
AliasAnalysisKind aliasAnalysis;
|
|
|
|
explicit constexpr OperatorGeneratorArgs(
|
|
torch::detail::SelectiveStr<true> schema_str,
|
|
void (*op)(Stack&),
|
|
AliasAnalysisKind aa)
|
|
: schema_str(schema_str),
|
|
isOperationCreator(false),
|
|
operation(op),
|
|
aliasAnalysis(aa) {}
|
|
|
|
explicit constexpr OperatorGeneratorArgs(
|
|
torch::detail::SelectiveStr<true> schema_str,
|
|
OperationCreator opCreator,
|
|
AliasAnalysisKind aa)
|
|
: schema_str(schema_str),
|
|
isOperationCreator(true),
|
|
operationCreator(opCreator),
|
|
aliasAnalysis(aa) {}
|
|
|
|
template <typename... Args>
|
|
explicit constexpr OperatorGeneratorArgs(
|
|
torch::detail::SelectiveStr<false> /*unused*/,
|
|
Args... /*unused*/)
|
|
: schema_str(nullptr),
|
|
isOperationCreator(false),
|
|
operation(nullptr),
|
|
aliasAnalysis(AliasAnalysisKind::INTERNAL_SPECIAL_CASE) {}
|
|
};
|
|
|
|
// Expands to three comma-separated OperatorGeneratorArgs for `aten_op`: the
// `.int_int`, `.float_float` and `.complex_complex` overloads. Each lambda pops
// two operands into locals named `a` and `b` and pushes `op`, so `op` must be
// an expression written in terms of `a` and `b`. The int and float overloads
// share `int_float_result` as their schema return type.
#define DEFINE_GENERIC_BINARY_OP( \
    aten_op, op, int_float_result, complex_result) \
  OperatorGeneratorArgs( \
      TORCH_SELECTIVE_SCHEMA( \
          #aten_op ".int_int(int a, int b) -> " #int_float_result), \
      [](Stack& stack) { \
        int64_t a; \
        int64_t b; \
        pop(stack, a, b); \
        push(stack, op); \
      }, \
      aliasAnalysisFromSchema()), \
  OperatorGeneratorArgs( \
      TORCH_SELECTIVE_SCHEMA( \
          #aten_op ".float_float(float a, float b) -> " #int_float_result), \
      [](Stack& stack) { \
        double a; \
        double b; \
        pop(stack, a, b); \
        push(stack, op); \
      }, \
      aliasAnalysisFromSchema()), \
  OperatorGeneratorArgs( \
      TORCH_SELECTIVE_SCHEMA( \
          #aten_op \
          ".complex_complex(complex a, complex b) -> " #complex_result), \
      [](Stack& stack) { \
        c10::complex<double> a; \
        c10::complex<double> b; \
        pop(stack, a, b); \
        push(stack, op); \
      }, \
      aliasAnalysisFromSchema())
|
|
|
|
// Implementations for primitive number ops.
// Expands to two comma-separated OperatorGeneratorArgs for `aten_op`: an
// `.int` overload that pushes `int_op` and a `.float` overload that pushes
// `float_op`. Both expressions are evaluated with the popped operands bound to
// locals named `a` and `b`.
#define DEFINE_GENERIC_OP(aten_op, int_op, float_op, int_result, float_result) \
  OperatorGeneratorArgs( \
      TORCH_SELECTIVE_SCHEMA(#aten_op ".int(int a, int b) -> " #int_result), \
      [](Stack& stack) { \
        int64_t a; \
        int64_t b; \
        pop(stack, a, b); \
        push(stack, int_op); \
      }, \
      aliasAnalysisFromSchema()), \
  OperatorGeneratorArgs( \
      TORCH_SELECTIVE_SCHEMA( \
          #aten_op ".float(float a, float b) -> " #float_result), \
      [](Stack& stack) { \
        double a; \
        double b; \
        pop(stack, a, b); \
        push(stack, float_op); \
      }, \
      aliasAnalysisFromSchema())
|
|
|
|
// Expands to the two mixed-type overloads of `aten_op`: `.int_float`
// (a is int64_t, b is double) and `.float_int` (a is double, b is int64_t).
// Both push `op`, evaluated in terms of the locals `a` and `b`.
#define DEFINE_INT_FLOAT_OP(aten_op, op, result) \
  OperatorGeneratorArgs( \
      TORCH_SELECTIVE_SCHEMA( \
          #aten_op ".int_float(int a, float b) -> " #result), \
      [](Stack& stack) { \
        int64_t a; \
        double b; \
        pop(stack, a, b); \
        push(stack, op); \
      }, \
      aliasAnalysisFromSchema()), \
  OperatorGeneratorArgs( \
      TORCH_SELECTIVE_SCHEMA( \
          #aten_op ".float_int(float a, int b) -> " #result), \
      [](Stack& stack) { \
        double a; \
        int64_t b; \
        pop(stack, a, b); \
        push(stack, op); \
      }, \
      aliasAnalysisFromSchema())
|
|
|
|
// Expands to a single `.int(int a, int b) -> int` overload of `aten_op` that
// pushes `op`, evaluated in terms of the int64_t locals `a` and `b`.
#define DEFINE_INT_OP(aten_op, op) \
  OperatorGeneratorArgs( \
      TORCH_SELECTIVE_SCHEMA(#aten_op ".int(int a, int b) -> int"), \
      [](Stack& stack) { \
        int64_t a; \
        int64_t b; \
        pop(stack, a, b); \
        push(stack, op); /* NOLINT(hicpp-signed-bitwise) */ \
      }, \
      aliasAnalysisFromSchema())
|
|
|
|
// Expands to a `.str(str a, str b) -> bool` overload of `aten_op`. The right
// operand `b` is popped first, then `a`; `op` is evaluated in terms of those
// two locals and its result is pushed.
#define DEFINE_STR_CMP_OP(aten_op, op) \
  OperatorGeneratorArgs( \
      TORCH_SELECTIVE_SCHEMA(#aten_op ".str(str a, str b) -> bool"), \
      [](Stack& stack) { \
        auto b = pop(stack).toStringRef(); \
        auto a = pop(stack).toStringRef(); \
        push(stack, op); \
      }, \
      aliasAnalysisFromSchema())
|
|
|
|
// Primitive op over Scalar operands.
// This overload has to be registered after the int/float variations,
// otherwise Scalar arguments get trapped in unintended implicit conversions.
//
// The lambda pops two IValues and dispatches on whether each holds a double:
// if either operand is a double `float_op` is pushed, otherwise both are read
// as ints and `int_op` is pushed. The expressions see the unwrapped operands
// as locals `a` and `b`. `string_val` is spliced into the schema between the
// op name and the argument list (an overload suffix, or "").
#define DEFINE_SCALAR_BINARY_OP_AVOID_COLLISION_GENERIC( \
    aten_op, int_op, float_op, result, string_val) \
  OperatorGeneratorArgs( \
      TORCH_SELECTIVE_SCHEMA( \
          #aten_op string_val "(Scalar a, Scalar b) -> " #result), \
      [](Stack& stack) { \
        IValue x; \
        IValue y; \
        pop(stack, x, y); \
        if (x.isDouble()) { \
          double a = x.toDouble(); \
          if (y.isDouble()) { \
            double b = y.toDouble(); \
            push(stack, float_op); \
          } else { \
            int64_t b = y.toInt(); \
            push(stack, float_op); \
          } \
        } else { \
          int64_t a = x.toInt(); \
          if (y.isDouble()) { \
            double b = y.toDouble(); \
            push(stack, float_op); \
          } else { \
            int64_t b = y.toInt(); \
            push(stack, int_op); \
          } \
        } \
      }, \
      aliasAnalysisFromSchema())
|
|
|
|
// Scalar/Scalar overload with no overload suffix in the schema name
// (forwards to the generic macro with an empty `string_val`).
#define DEFINE_SCALAR_BINARY_OP(aten_op, int_op, float_op, result) \
  DEFINE_SCALAR_BINARY_OP_AVOID_COLLISION_GENERIC( \
      aten_op, int_op, float_op, result, "")
|
|
|
|
// Scalar/Scalar overload registered under the explicit overload name
// ".Scalar_Scalar" (forwards to the generic macro with that suffix).
#define DEFINE_SCALAR_BINARY_OP_AVOID_COLLISION( \
    aten_op, int_op, float_op, result) \
  DEFINE_SCALAR_BINARY_OP_AVOID_COLLISION_GENERIC( \
      aten_op, int_op, float_op, result, ".Scalar_Scalar")
|
|
|
|
// Full set of real-number overloads for a binary op whose result type follows
// its operands: int/int -> int, float/float -> float, mixed int/float ->
// float, and Scalar/Scalar -> Scalar. The same `op` expression is used for
// every overload.
#define DEFINE_BINARY_OP(aten_op, op) \
  DEFINE_GENERIC_OP(aten_op, op, op, int, float), \
  DEFINE_INT_FLOAT_OP(aten_op, op, float), \
  DEFINE_SCALAR_BINARY_OP(aten_op, op, op, Scalar)
|
|
|
|
// Like DEFINE_BINARY_OP, but every overload (including int/int and
// Scalar/Scalar) is declared with a `float` result in its schema.
#define DEFINE_BINARY_FLOAT_OP(aten_op, op) \
  DEFINE_GENERIC_OP(aten_op, op, op, float, float), \
  DEFINE_INT_FLOAT_OP(aten_op, op, float), \
  DEFINE_SCALAR_BINARY_OP(aten_op, op, op, float)
|
|
|
|
// Full set of overloads for a comparison op: int/int, float/float, mixed
// int/float, Scalar/Scalar and str/str, all declared with a `bool` result.
// The same `op` expression is used for every overload.
#define DEFINE_COMPARISON_OP(aten_op, op) \
  DEFINE_GENERIC_OP(aten_op, op, op, bool, bool), \
  DEFINE_INT_FLOAT_OP(aten_op, op, bool), \
  DEFINE_SCALAR_BINARY_OP(aten_op, op, op, bool), \
  DEFINE_STR_CMP_OP(aten_op, op)
|
|
|
|
// Expands to a `.int(int a) -> result` overload of `aten_op` that pops one
// int64_t into a local named `a` and pushes `op`.
#define DEFINE_UNARY_INT_OP(aten_op, op, result) \
  OperatorGeneratorArgs( \
      TORCH_SELECTIVE_SCHEMA(#aten_op ".int(int a) -> " #result), \
      [](Stack& stack) { \
        int64_t a; \
        pop(stack, a); \
        push(stack, op); \
      }, \
      aliasAnalysisFromSchema())
|
|
|
|
// Expands to a `.float(float a) -> result` overload of `aten_op` that pops one
// double into a local named `a` and pushes `op`.
#define DEFINE_UNARY_FLOAT_OP(aten_op, op, result) \
  OperatorGeneratorArgs( \
      TORCH_SELECTIVE_SCHEMA(#aten_op ".float(float a) -> " #result), \
      [](Stack& stack) { \
        double a; \
        pop(stack, a); \
        push(stack, op); \
      }, \
      aliasAnalysisFromSchema())
|
|
|
|
// Full set of real-number overloads for a unary op: `.int`, `.float` and
// `.Scalar`. In the Scalar overload the operand is unwrapped into a local `a`
// (double if the IValue holds a double, int64_t otherwise) and the result of
// `op` is static_cast to `float_result` or `int_result` respectively before
// being pushed; those two arguments are therefore used both as schema text and
// as C++ type names.
#define DEFINE_UNARY_OP(aten_op, op, int_result, float_result) \
  DEFINE_UNARY_INT_OP(aten_op, op, int_result), \
  DEFINE_UNARY_FLOAT_OP(aten_op, op, float_result), \
  OperatorGeneratorArgs( \
      TORCH_SELECTIVE_SCHEMA(#aten_op ".Scalar(Scalar a) -> Scalar"), \
      [](Stack& stack) { \
        IValue x; \
        pop(stack, x); \
        if (x.isDouble()) { \
          double a = x.toDouble(); \
          push(stack, static_cast<float_result>(op)); \
        } else { \
          int64_t a = x.toInt(); \
          push(stack, static_cast<int_result>(op)); \
        } \
      }, \
      aliasAnalysisFromSchema())
|
|
// Expands to a `.bool(bool a, bool b) -> bool` overload of `aten_op` that pops
// two bools into locals `a` and `b` and pushes `op`.
#define DEFINE_BOOL_OP(aten_op, op) \
  OperatorGeneratorArgs( \
      TORCH_SELECTIVE_SCHEMA(#aten_op ".bool(bool a, bool b) -> bool"), \
      [](Stack& stack) { \
        bool a; \
        bool b; \
        pop(stack, a, b); \
        push(stack, op); \
      }, \
      aliasAnalysisFromSchema())
|
|
// Expands to a `.str(str a, str b)` overload of `op_name` returning `result`.
// The right operand `b` is popped first, then `a`; `string_op` is evaluated in
// terms of those two locals and its result is pushed.
// NOTE(review): unlike the sibling macros, the schema text here has no space
// between "->" and the result type; left as-is.
#define DEFINE_STRING_OP(op_name, string_op, result) \
  OperatorGeneratorArgs( \
      TORCH_SELECTIVE_SCHEMA(#op_name ".str(str a, str b) ->" #result), \
      [](Stack& stack) { \
        auto b = pop(stack).toStringRef(); \
        auto a = pop(stack).toStringRef(); \
        push(stack, string_op); \
      }, \
      aliasAnalysisFromSchema())
|
|
|
|
//-----------------------------------------------------------------------------
|
|
//-----------------------------------------------------------------------------
|
|
//-----------------------------------------------------------------------------
|
|
//-----------------------------------------------------------------------------
|
|
// Expands to a `.complex(complex a) -> result` overload of `aten_op` that pops
// one c10::complex<double> into a local named `a` and pushes `op`.
#define DEFINE_UNARY_COMPLEX_OP(aten_op, op, result) \
  OperatorGeneratorArgs( \
      TORCH_SELECTIVE_SCHEMA(#aten_op ".complex(complex a) -> " #result), \
      [](Stack& stack) { \
        c10::complex<double> a; \
        pop(stack, a); \
        push(stack, op); \
      }, \
      aliasAnalysisFromSchema())
|
|
|
|
// Some complex unary ops (like abs, angle) return real valued output, but most
// other unary ops return complex valued output. So, this macro is used in the
// former case where we can explicitly pass complex_result_cast argument, which
// is set to c10::complex<double> in the macro `DEFINE_UNARY_OP_WITH_COMPLEX`
// defined below.
|
|
// Unary op with `.int`, `.float`, `.complex` and `.Scalar` overloads.
// In the Scalar overload the operand is unwrapped into a local `a` according
// to what the IValue holds (double, complex double, otherwise int) and the
// result of `op` is static_cast to `float_result`, `complex_result_cast` or
// `int_result` respectively before being pushed. `complex_result` is only used
// as the schema return type of the `.complex` overload.
#define DEFINE_UNARY_OP_WITH_COMPLEX_CAST( \
    aten_op, \
    op, \
    int_result, \
    float_result, \
    complex_result, \
    complex_result_cast) \
  DEFINE_UNARY_INT_OP(aten_op, op, int_result), \
  DEFINE_UNARY_FLOAT_OP(aten_op, op, float_result), \
  DEFINE_UNARY_COMPLEX_OP(aten_op, op, complex_result), \
  OperatorGeneratorArgs( \
      TORCH_SELECTIVE_SCHEMA(#aten_op ".Scalar(Scalar a) -> Scalar"), \
      [](Stack& stack) { \
        IValue x; \
        pop(stack, x); \
        if (x.isDouble()) { \
          double a = x.toDouble(); \
          push(stack, static_cast<float_result>(op)); \
        } else if (x.isComplexDouble()) { \
          c10::complex<double> a = x.toComplexDouble(); \
          push(stack, static_cast<complex_result_cast>(op)); \
        } else { \
          int64_t a = x.toInt(); \
          push(stack, static_cast<int_result>(op)); \
        } \
      }, \
      aliasAnalysisFromSchema())
|
|
|
|
// Unary op whose complex overload also returns a complex value: forwards to
// DEFINE_UNARY_OP_WITH_COMPLEX_CAST with `complex` as the schema result and
// c10::complex<double> as the cast type.
#define DEFINE_UNARY_OP_WITH_COMPLEX(aten_op, op, int_result, float_result) \
  DEFINE_UNARY_OP_WITH_COMPLEX_CAST( \
      aten_op, op, int_result, float_result, complex, c10::complex<double>)
|
|
|
|
// Expands to three same-type overloads of `aten_op`, in this order: `.int`
// (pushes `int_op`), `.complex` (pushes `complex_op`) and `.float` (pushes
// `float_op`). Each expression is evaluated with the popped operands bound to
// locals named `a` and `b`.
#define DEFINE_GENERIC_OP_WITH_COMPLEX( \
    aten_op, \
    int_op, \
    float_op, \
    complex_op, \
    int_result, \
    float_result, \
    complex_result) \
  OperatorGeneratorArgs( \
      TORCH_SELECTIVE_SCHEMA(#aten_op ".int(int a, int b) -> " #int_result), \
      [](Stack& stack) { \
        int64_t a; \
        int64_t b; \
        pop(stack, a, b); \
        push(stack, int_op); \
      }, \
      aliasAnalysisFromSchema()), \
  OperatorGeneratorArgs( \
      TORCH_SELECTIVE_SCHEMA( \
          #aten_op ".complex(complex a, complex b) -> " #complex_result), \
      [](Stack& stack) { \
        c10::complex<double> a; \
        c10::complex<double> b; \
        pop(stack, a, b); \
        push(stack, complex_op); \
      }, \
      aliasAnalysisFromSchema()), \
  OperatorGeneratorArgs( \
      TORCH_SELECTIVE_SCHEMA( \
          #aten_op ".float(float a, float b) -> " #float_result), \
      [](Stack& stack) { \
        double a; \
        double b; \
        pop(stack, a, b); \
        push(stack, float_op); \
      }, \
      aliasAnalysisFromSchema())
|
|
|
|
// Expands to the two mixed int/complex overloads of `aten_op`: `.int_complex`
// (a is int64_t, b is c10::complex<double>) and `.complex_int` (the reverse).
// Both push `op`, evaluated in terms of the locals `a` and `b`.
#define DEFINE_INT_COMPLEX_OP(aten_op, op, result) \
  OperatorGeneratorArgs( \
      TORCH_SELECTIVE_SCHEMA( \
          #aten_op ".int_complex(int a, complex b) -> " #result), \
      [](Stack& stack) { \
        int64_t a; \
        c10::complex<double> b; \
        pop(stack, a, b); \
        push(stack, op); \
      }, \
      aliasAnalysisFromSchema()), \
  OperatorGeneratorArgs( \
      TORCH_SELECTIVE_SCHEMA( \
          #aten_op ".complex_int(complex a, int b) -> " #result), \
      [](Stack& stack) { \
        c10::complex<double> a; \
        int64_t b; \
        pop(stack, a, b); \
        push(stack, op); \
      }, \
      aliasAnalysisFromSchema())
|
|
|
|
// Expands to the two mixed float/complex overloads of `aten_op`:
// `.float_complex` (a is double, b is c10::complex<double>) and
// `.complex_float` (the reverse). Both push `op`, evaluated in terms of the
// locals `a` and `b`.
#define DEFINE_FLOAT_COMPLEX_OP(aten_op, op, result) \
  OperatorGeneratorArgs( \
      TORCH_SELECTIVE_SCHEMA( \
          #aten_op ".float_complex(float a, complex b) -> " #result), \
      [](Stack& stack) { \
        double a; \
        c10::complex<double> b; \
        pop(stack, a, b); \
        push(stack, op); \
      }, \
      aliasAnalysisFromSchema()), \
  OperatorGeneratorArgs( \
      TORCH_SELECTIVE_SCHEMA( \
          #aten_op ".complex_float(complex a, float b) -> " #result), \
      [](Stack& stack) { \
        c10::complex<double> a; \
        double b; \
        pop(stack, a, b); \
        push(stack, op); \
      }, \
      aliasAnalysisFromSchema())
|
|
|
|
// Scalar/Scalar overload that also understands complex operands.
// The lambda pops two IValues and unwraps each into a local (`a`, `b`) typed
// by what the IValue holds: complex double, double, or otherwise int.
// Dispatch on the pair:
//   - either operand complex         -> push `complex_op`
//   - else either operand double     -> push `float_op`
//   - else (both read as int)        -> push `int_op`
// `string_val` is spliced into the schema between the op name and the
// argument list (an overload suffix, or "").
#define DEFINE_SCALAR_BINARY_OP_WITH_COMPLEX_AVOID_COLLISION_GENERIC( \
    aten_op, int_op, float_op, complex_op, result, string_val) \
  OperatorGeneratorArgs( \
      TORCH_SELECTIVE_SCHEMA( \
          #aten_op string_val "(Scalar a, Scalar b) -> " #result), \
      [](Stack& stack) { \
        IValue x; \
        IValue y; \
        pop(stack, x, y); \
        if (x.isComplexDouble()) { \
          c10::complex<double> a = x.toComplexDouble(); \
          if (y.isComplexDouble()) { \
            c10::complex<double> b = y.toComplexDouble(); \
            push(stack, complex_op); \
          } else if (y.isDouble()) { \
            double b = y.toDouble(); \
            push(stack, complex_op); \
          } else { \
            int64_t b = y.toInt(); \
            push(stack, complex_op); \
          } \
        } else if (x.isDouble()) { \
          double a = x.toDouble(); \
          if (y.isComplexDouble()) { \
            c10::complex<double> b = y.toComplexDouble(); \
            push(stack, complex_op); \
          } else if (y.isDouble()) { \
            double b = y.toDouble(); \
            push(stack, float_op); \
          } else { \
            int64_t b = y.toInt(); \
            push(stack, float_op); \
          } \
        } else { \
          int64_t a = x.toInt(); \
          if (y.isComplexDouble()) { \
            c10::complex<double> b = y.toComplexDouble(); \
            push(stack, complex_op); \
          } else if (y.isDouble()) { \
            double b = y.toDouble(); \
            push(stack, float_op); \
          } else { \
            int64_t b = y.toInt(); \
            push(stack, int_op); \
          } \
        } \
      }, \
      aliasAnalysisFromSchema())
|
|
|
|
// Scalar/Scalar overload with complex support, except for the int/complex
// pairing. Operands are unwrapped into locals `a` and `b` as in the generic
// complex variant, but only these combinations push a result:
//   complex with complex or double -> `complex_op`
//   double  with complex           -> `complex_op`
//   double  with double or int     -> `float_op`
//   int     with double            -> `float_op`
//   int     with int               -> `int_op`
// NOTE(review): for complex-with-int and int-with-complex no branch matches,
// so the lambda returns without pushing anything — confirm that callers never
// reach this overload with such a pair.
#define DEFINE_SCALAR_BINARY_OP_WITH_COMPLEX_WITHOUT_INT_COMPLEX_PAIR( \
    aten_op, int_op, float_op, complex_op, result) \
  OperatorGeneratorArgs( \
      TORCH_SELECTIVE_SCHEMA(#aten_op "(Scalar a, Scalar b) -> " #result), \
      [](Stack& stack) { \
        IValue x; \
        IValue y; \
        pop(stack, x, y); \
        if (x.isComplexDouble()) { \
          c10::complex<double> a = x.toComplexDouble(); \
          if (y.isComplexDouble()) { \
            c10::complex<double> b = y.toComplexDouble(); \
            push(stack, complex_op); \
          } else if (y.isDouble()) { \
            double b = y.toDouble(); \
            push(stack, complex_op); \
          } \
        } else if (x.isDouble()) { \
          double a = x.toDouble(); \
          if (y.isComplexDouble()) { \
            c10::complex<double> b = y.toComplexDouble(); \
            push(stack, complex_op); \
          } else if (y.isDouble()) { \
            double b = y.toDouble(); \
            push(stack, float_op); \
          } else { \
            int64_t b = y.toInt(); \
            push(stack, float_op); \
          } \
        } else { \
          int64_t a = x.toInt(); \
          if (y.isDouble()) { \
            double b = y.toDouble(); \
            push(stack, float_op); \
          } else if (y.isInt()) { \
            int64_t b = y.toInt(); \
            push(stack, int_op); \
          } \
        } \
      }, \
      aliasAnalysisFromSchema())
|
|
|
|
// Complex-aware Scalar/Scalar overload with no overload suffix in the schema
// name (forwards to the generic macro with an empty `string_val`).
#define DEFINE_SCALAR_BINARY_OP_WITH_COMPLEX( \
    aten_op, int_op, float_op, complex_op, result) \
  DEFINE_SCALAR_BINARY_OP_WITH_COMPLEX_AVOID_COLLISION_GENERIC( \
      aten_op, int_op, float_op, complex_op, result, "")
|
|
|
|
// Full set of overloads for a binary op that also accepts complex operands:
// same-type int/complex/float, mixed int/complex, mixed float/complex, mixed
// int/float, and Scalar/Scalar. The same `op` expression is used for every
// overload.
#define DEFINE_BINARY_OP_WITH_COMPLEX(aten_op, op) \
  DEFINE_GENERIC_OP_WITH_COMPLEX(aten_op, op, op, op, int, float, complex), \
  DEFINE_INT_COMPLEX_OP(aten_op, op, complex), \
  DEFINE_FLOAT_COMPLEX_OP(aten_op, op, complex), \
  DEFINE_INT_FLOAT_OP(aten_op, op, float), \
  DEFINE_SCALAR_BINARY_OP_WITH_COMPLEX(aten_op, op, op, op, Scalar)
|
|
|
|
// Full set of overloads for a comparison op that also accepts complex
// operands, all declared with a `bool` result: same-type int/complex/float,
// mixed int/float, mixed float/complex, Scalar/Scalar (without the
// int/complex pairing) and str/str. No mixed int/complex overload is
// generated here.
#define DEFINE_COMPARISON_OP_WITH_COMPLEX(aten_op, op) \
  DEFINE_GENERIC_OP_WITH_COMPLEX(aten_op, op, op, op, bool, bool, bool), \
  DEFINE_INT_FLOAT_OP(aten_op, op, bool), \
  DEFINE_FLOAT_COMPLEX_OP(aten_op, op, bool), \
  DEFINE_SCALAR_BINARY_OP_WITH_COMPLEX_WITHOUT_INT_COMPLEX_PAIR( \
      aten_op, op, op, op, bool), \
  DEFINE_STR_CMP_OP(aten_op, op)
|
|
|
|
// Returns an at::Generator for `device`. `seed` is optional and defaults to
// std::nullopt. Declaration only; the body is defined out-of-line.
// NOTE(review): presumably the generator is seeded with `seed` when one is
// given — confirm against the definition.
TORCH_API at::Generator make_generator_for_device(
    c10::Device device,
    std::optional<int64_t> seed = std::nullopt);
|
|
|
|
} // namespace torch::jit
|