Files
pytorch/torch/csrc/jit/runtime/vararg_functions.cpp
FFFrog af0bc75460 Remove deprecated alias macro(1/3) (#137556)
**Detailed Descriptions:**
- Remove AT_ERROR Macro

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137556
Approved by: https://github.com/ezyang
2024-10-21 17:32:32 +00:00

435 lines
12 KiB
C++

#include <torch/csrc/jit/runtime/vararg_functions.h>
#include <ATen/Functions.h>
#include <ATen/Tensor.h>
#include <ATen/core/class_type.h>
#include <c10/util/irange.h>
namespace torch::jit {
namespace {
// Default value for the `precision` argument of addFormattedArg().
constexpr int defaultPrecision = 6;
// IValue tags are intentionally private, so we need additional logic to cast
// the IValue type to the specified format.
//
// Appends `ival` to `ss`, rendered according to the printf-style conversion
// character `key`:
//   d, i       - integer; a non-int scalar is truncated to int64_t
//   e, E       - scientific notation with `precision` digits (E: uppercase)
//   f, F       - fixed notation with `precision` digits
//   c          - an int converted to char, or a string of length 1
//   s          - strings are written verbatim, anything else through the
//                IValue stream operator
// Any other key, or an argument of the wrong kind, fails a TORCH_CHECK.
void addFormattedArg(
    char key,
    const IValue& ival,
    std::stringstream& ss,
    int precision = defaultPrecision) {
  // TODO: Implement precision-based formatting

  // Shared validation for every numeric conversion.
  const auto requireNumber = [&]() {
    TORCH_CHECK(
        ival.isScalar(),
        "%",
        key,
        " requires a number for formatting, but got ",
        ival.tagKind());
  };
  // Widen the operand to double. Narrowing to float here would silently
  // drop digits (e.g. 16777217 would print as 16777216).
  const auto asDouble = [&]() -> double {
    return ival.isInt() ? static_cast<double>(ival.toInt()) : ival.toDouble();
  };

  switch (key) {
    case 'd':
    case 'i':
      requireNumber();
      if (ival.isInt()) {
        ss << ival.toInt();
      } else {
        // Truncate to the same 64-bit width used for integer IValues; a
        // 32-bit int cast is undefined for values outside the int range.
        ss << static_cast<int64_t>(ival.toDouble());
      }
      break;
    case 'e':
    case 'E': {
      requireNumber();
      // Format into a scratch stream so the manipulators do not leak into
      // the caller's stream state.
      std::stringstream tmp;
      tmp << std::setprecision(precision) << std::scientific;
      if (key == 'E') {
        tmp << std::uppercase;
      }
      tmp << asDouble();
      ss << tmp.str();
      break;
    }
    case 'f':
    case 'F': {
      requireNumber();
      std::stringstream tmp;
      tmp << std::setprecision(precision) << std::fixed;
      tmp << asDouble();
      ss << tmp.str();
      break;
    }
    case 'c':
      TORCH_CHECK(
          ival.isInt() || (ival.isString() && ival.toStringRef().length() == 1),
          "%",
          key,
          " requires an int or char for formatting, but got ",
          ival.tagKind());
      if (ival.isInt()) {
        ss << static_cast<char>(ival.toInt());
      } else {
        ss << ival.toStringRef();
      }
      break;
    case 's':
      if (ival.isString()) {
        ss << ival.toStringRef();
      } else {
        ss << ival;
      }
      break;
    default:
      TORCH_CHECK(
          false,
          "The specifier %",
          key,
          " is not supported in TorchScript format strings");
  }
}
} // namespace
// Pops the tuple on top of the stack and appends its elements to the stack
// in their original order.
void tupleUnpack(Stack& stack) {
  auto popped = pop(stack).toTuple();
  const auto& members = popped->elements();
  stack.insert(stack.end(), members.begin(), members.end());
}
// Consumes the top `num_inputs` stack entries: the first is the format
// string and the rest are its arguments. Each "{}" in the format string is
// replaced by the next argument, streamed into the output; the resulting
// string is pushed back onto the stack.
//
// Fails a TORCH_CHECK when `num_inputs` is zero or exceeds the stack size,
// or when there are more "{}" placeholders than arguments. Arguments left
// over after the last placeholder are not reported.
void format(Stack& stack, size_t num_inputs) {
  TORCH_CHECK(
      num_inputs > 0 && num_inputs <= stack.size(),
      "Invalid number of inputs for format string: ",
      num_inputs);

  // A warning for unsupported "{...}" format options used to be emitted here
  // via std::regex. It is disabled because of the "StdRegexIsAwful" internal
  // lint error (std::regex must stay out of PT mobile).
  const std::string fmt = peek(stack, 0, num_inputs).toStringRef();
  const auto values = last(stack, num_inputs - 1);

  std::stringstream out;
  size_t cursor = 0;
  size_t next_value = 0;
  for (size_t hole = fmt.find("{}", cursor); hole != std::string::npos;
       hole = fmt.find("{}", cursor)) {
    // Literal text preceding the placeholder.
    out << fmt.substr(cursor, hole - cursor);
    TORCH_CHECK(
        next_value < values.size(),
        "Too few arguments for format string: ",
        fmt);
    out << values[next_value];
    ++next_value;
    cursor = hole + 2; // step over "{}"
  }
  // Trailing literal text after the last placeholder.
  out << fmt.substr(cursor);

  drop(stack, num_inputs);
  push(stack, out.str());
}
// einsum in "sublist" calling convention. The top `num_inputs` stack entries
// interleave operands and their subscript lists (operand, List[int],
// operand, List[int], ...), optionally followed by one trailing List[int]
// holding the output subscripts (present when `num_inputs` is odd).
//
// The subscript lists are turned into an equation string (0..25 -> 'A'..'Z',
// 26..51 -> 'a'..'z'), the operands are gathered into a tensor list, and the
// result of at::einsum(equation, operands) replaces the inputs on the stack.
void einsum(Stack& stack, size_t num_inputs) {
  TORCH_CHECK(
      num_inputs >= 2,
      "einsum(): must specify the equation string and at least one operand, ",
      "or at least one operand and its subscripts list");

  const auto inputs = last(stack, num_inputs);
  const bool has_output_sublist = (num_inputs % 2 == 1);

  std::stringstream eq;
  // Appends the letters for one subscript list to the equation.
  const auto append_labels = [&eq](
                                 const c10::List<int64_t>& labels,
                                 size_t arg_index) {
    for (const auto pos : c10::irange(labels.size())) {
      const int64_t label = labels[pos];
      TORCH_CHECK(
          label >= 0 && label < 52,
          "einsum(): expected subscript ",
          pos,
          " in argument ",
          arg_index,
          " to be within the range [0, 52), but got ",
          label);
      const char letter = label < 26 ? static_cast<char>(label + 'A')
                                     : static_cast<char>(label - 26 + 'a');
      eq << letter;
    }
  };

  // Input subscripts sit at the odd positions.
  for (size_t idx = 1; idx < num_inputs; idx += 2) {
    const auto& sublist = inputs[idx];
    TORCH_CHECK(
        sublist.isIntList(),
        "einsum(): expected List[int] in argument ",
        idx,
        ", but got ",
        sublist.type()->repr_str());
    append_labels(sublist.toIntList(), idx);
    // Separate operands, but do not emit a comma after the last one.
    if (idx + 2 < num_inputs) {
      eq << ',';
    }
  }

  // Optional output subscripts are the final argument.
  if (has_output_sublist) {
    const auto& out_sublist = inputs.back();
    TORCH_CHECK(
        out_sublist.isIntList(),
        "einsum(): expected List[int] in argument ",
        num_inputs - 1,
        ", but got ",
        out_sublist.type()->repr_str());
    eq << "->";
    append_labels(out_sublist.toIntList(), num_inputs - 1);
  }
  const auto equation = eq.str();

  // Operands sit at the even positions, excluding the output sublist.
  const size_t operand_end = has_output_sublist ? num_inputs - 1 : num_inputs;
  std::vector<at::Tensor> tensors;
  for (size_t idx = 0; idx < operand_end; idx += 2) {
    const auto& operand = inputs[idx];
    TORCH_CHECK(
        operand.isTensor(),
        "einsum(): expected Tensor in argument ",
        idx,
        ", but got ",
        operand.type()->repr_str());
    tensors.emplace_back(operand.toTensor());
  }

  drop(stack, num_inputs);
  push(stack, at::einsum(equation, tensors));
}
// printf-style ("%") string formatting. Of the top `num_inputs` stack
// entries the first is the format string and the second is the argument:
// either a tuple (one element per conversion) or a single value (treated as
// exactly one argument). Every "%<key>" is rendered via addFormattedArg and
// "%%" produces a literal '%'. The inputs are replaced by the result string.
//
// Fails a TORCH_CHECK when `num_inputs` is out of range, when the format
// string ends in a lone '%', or when the number of conversions does not
// match the number of arguments.
void percentFormat(Stack& stack, size_t num_inputs) {
  // Both the format string and the argument must be present; without this
  // check `num_inputs - 1` underflows and the argument lookup below reads
  // outside the stack.
  TORCH_CHECK(
      num_inputs >= 2 && num_inputs <= stack.size(),
      "Invalid number of inputs for format string: ",
      num_inputs);

  // The stack is not modified until the final drop(), so references into it
  // stay valid for the whole formatting loop.
  const std::string& format_str = peek(stack, 0, num_inputs).toStringRef();
  const IValue& args = last(stack, num_inputs - 1)[0];
  const bool args_is_tuple = args.isTuple();
  // A non-tuple argument counts as exactly one value.
  const size_t args_size =
      args_is_tuple ? args.toTupleRef().elements().size() : 1;

  std::stringstream ss;
  size_t used_args = 0;
  size_t begin = 0;
  while (true) {
    const size_t percent_idx = format_str.find('%', begin);
    if (percent_idx == std::string::npos) {
      ss << format_str.substr(begin);
      break;
    }
    const size_t format_idx = percent_idx + 1;
    // A '%' was found, so the string is non-empty and length() - 1 cannot
    // underflow.
    TORCH_CHECK(
        percent_idx < format_str.length() - 1, "Incomplete format specifier");
    ss << format_str.substr(begin, percent_idx - begin);
    const char key = format_str.at(format_idx);
    if (key == '%') {
      ss << '%';
      begin = percent_idx + 2; // skip the `%` and the format specifier
      continue;
    }
    TORCH_CHECK(used_args < args_size, "Too few arguments for format string");
    const IValue& arg =
        args_is_tuple ? args.toTupleRef().elements()[used_args] : args;
    addFormattedArg(key, arg, ss);
    begin = percent_idx + 2;
    ++used_args;
  }
  TORCH_CHECK(used_args == args_size, "Too many arguments for format string");
  drop(stack, num_inputs);
  push(stack, ss.str());
}
// Pops a list from the stack and pushes its elements in order. Fails a
// TORCH_CHECK unless the list holds exactly `num_outputs` elements.
void listUnpack(Stack& stack, size_t num_outputs) {
  auto unpacked = pop(stack).toList();
  const auto actual_size = unpacked.size();
  TORCH_CHECK(
      actual_size == num_outputs,
      "Expected ",
      num_outputs,
      " elements in a list but found ",
      actual_size);
  stack.insert(stack.end(), unpacked.begin(), unpacked.end());
}
// Replaces the top `num_inputs` stack entries with a single tuple holding
// them in stack order. Arities 0 through 3 are handled individually; larger
// arities move the entries into a vector first. Fails a TORCH_CHECK when
// `num_inputs` exceeds the stack size.
void tupleConstruct(Stack& stack, size_t num_inputs) {
  TORCH_CHECK(
      num_inputs <= stack.size(), "Invalid number of inputs: ", num_inputs);

  const size_t depth = stack.size();
  switch (num_inputs) {
    case 0:
      // Nothing to consume; the empty tuple is simply appended.
      stack.emplace_back(c10::ivalue::Tuple::create());
      break;
    case 1:
      // Wrap the top entry in place.
      stack.back() = c10::ivalue::Tuple::create(std::move(stack.back()));
      break;
    case 2: {
      auto packed = c10::ivalue::Tuple::create(
          std::move(stack[depth - 2]), std::move(stack[depth - 1]));
      stack.pop_back();
      stack.back() = std::move(packed);
      break;
    }
    case 3: {
      auto packed = c10::ivalue::Tuple::create(
          std::move(stack[depth - 3]),
          std::move(stack[depth - 2]),
          std::move(stack[depth - 1]));
      stack.pop_back();
      stack.pop_back();
      stack.back() = std::move(packed);
      break;
    }
    default: {
      std::vector<IValue> members{
          std::make_move_iterator(stack.end() - num_inputs),
          std::make_move_iterator(stack.end())};
      // Keep one slot on the stack to receive the tuple.
      drop(stack, num_inputs - 1);
      stack.back() = c10::ivalue::Tuple::create(std::move(members));
      break;
    }
  }
}
// Replaces the top `num_inputs` stack entries with a named tuple of type
// `tuple_type` holding them in stack order. Fails a TORCH_CHECK when
// `num_inputs` exceeds the stack size.
void namedTupleConstruct(
    Stack& stack,
    c10::TypePtr tuple_type,
    size_t num_inputs) {
  // Same guard as tupleConstruct(): without it `stack.end() - num_inputs`
  // would form an iterator before the beginning of the stack.
  TORCH_CHECK(
      num_inputs <= stack.size(), "Invalid number of inputs: ", num_inputs);
  std::vector<IValue> elems{
      std::make_move_iterator(stack.end() - num_inputs),
      std::make_move_iterator(stack.end())};
  drop(stack, num_inputs);
  push(
      stack,
      c10::ivalue::Tuple::createNamed(std::move(elems), std::move(tuple_type)));
}
// Replaces the top `num_inputs` stack entries with a single list whose
// element type is `list_type.containedType(0)`, preserving stack order.
void listConstruct(
    Stack& stack,
    const c10::Type& list_type,
    size_t num_inputs) {
  // The list is built inside a helper that returns it by value so that NRVO
  // can construct it straight into the emplace_back argument instead of
  // move-constructing it; moving a List isn't free.
  const auto buildFromStackTop =
      [](Stack& src, const c10::Type& type, size_t count) {
        c10::List<IValue> result(type.containedType(0));
        result.reserve(count);
        const size_t stop = src.size();
        size_t pos = stop - count;
        while (pos < stop) {
          result.push_back(std::move(src[pos]));
          ++pos;
        }
        drop(src, count);
        return result;
      };
  stack.emplace_back(buildFromStackTop(stack, list_type, num_inputs));
}
// Replaces the top `num_inputs` stack entries, read as alternating
// key, value, key, value, ..., with a single dict. Key and value types come
// from `dict_type.containedType(0)` and `dict_type.containedType(1)`.
void dictConstruct(
    Stack& stack,
    const c10::Type& dict_type,
    size_t num_inputs) {
  c10::impl::GenericDict dict(
      dict_type.containedType(0), dict_type.containedType(1));
  dict.reserve(num_inputs / 2);

  // Walk the inputs from the bottom of the stack upwards so that entries are
  // inserted in the order in which they were supplied.
  const auto entries = last(stack, num_inputs);
  for (size_t key_pos = 0; key_pos < num_inputs; key_pos += 2) {
    const size_t value_pos = key_pos + 1;
    // The entries are copied (not moved) out of the stack view.
    dict.insert_or_assign(IValue(entries[key_pos]), IValue(entries[value_pos]));
  }

  drop(stack, num_inputs);
  push(stack, std::move(dict));
}
// Pushes a new object of class `type` with `type->numAttributes()` slots.
// When `as_weak_ref` is set the object refers to its type through a
// WeakTypePtr, otherwise through a StrongTypePtr.
void createObject(
    Stack& stack,
    const at::ClassTypePtr& type,
    bool as_weak_ref) {
  if (!as_weak_ref) {
    auto strong_obj = c10::ivalue::Object::create(
        c10::StrongTypePtr(type->compilation_unit(), type),
        type->numAttributes());
    push(stack, std::move(strong_obj));
    return;
  }

  c10::WeakTypePtr weak_type(type->compilation_unit(), type);
  auto weak_obj = c10::ivalue::Object::create(
      c10::WeakOrStrongTypePtr(weak_type), type->numAttributes());
  push(stack, std::move(weak_obj));
}
// Pops a value and pushes `true` if its type is a subtype of any entry in
// `types`, `false` otherwise (including when `types` is empty).
void isinstance(Stack& stack, at::ArrayRef<at::TypePtr> types) {
  const at::TypePtr actual = pop(stack).type();
  bool matches = false;
  for (const at::TypePtr& candidate : types) {
    if (actual->isSubtypeOf(*candidate)) {
      matches = true;
      break; // the first match decides
    }
  }
  push(stack, matches);
}
// Pops a tuple and pushes a new tuple made of its elements in the half-open
// index range [begin, end).
void tupleSlice(Stack& stack, size_t begin, size_t end) {
  auto source = pop(stack).toTuple();
  const auto& members = source->elements();
  // ArrayRef::slice takes (start, length) rather than (start, stop).
  const size_t length = end - begin;
  push(
      stack,
      c10::ivalue::Tuple::create(members.asArrayRef().slice(begin, length)));
}
// Pops a value and pushes its dequantized counterpart:
//   - Tuple: a new tuple in which every tensor element has been passed
//     through at::dequantize; non-tensor elements are copied unchanged.
//   - List[Tensor]: a new list holding at::dequantize of each element.
// Any other input type fails a TORCH_CHECK.
void dequantize(Stack& stack) {
  auto iv = pop(stack);
  if (iv.isTuple()) {
    auto tuple = iv.toTuple();
    const auto& elems = tuple->elements();
    std::vector<IValue> output_elems;
    output_elems.reserve(elems.size());
    for (const auto& elem : elems) {
      if (elem.isTensor()) {
        output_elems.emplace_back(at::dequantize(elem.toTensor()));
      } else {
        output_elems.emplace_back(elem);
      }
    }
    push(stack, c10::ivalue::Tuple::create(std::move(output_elems)));
  } else if (iv.isTensorList()) {
    auto elems = iv.toTensorList();
    auto output_list = c10::impl::GenericList(elems.elementType());
    // The output has exactly one entry per input tensor.
    output_list.reserve(elems.size());
    for (auto&& elem : elems) {
      output_list.emplace_back(at::dequantize(elem));
    }
    push(stack, std::move(output_list));
  } else {
    // Adjacent string literals are used instead of a backslash line
    // continuation, which would embed the next line's indentation in the
    // message.
    TORCH_CHECK(
        false,
        "Unsupported type in dequantize, only List[Tensor] and "
        "Tuple[Tensor or other types] are supported, got type: ",
        toString(iv.type()));
  }
}
} // namespace torch::jit