#include <aten/src/ATen/Context.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/custom_operator.h>
#include <torch/csrc/jit/fuser/interface.h>
#include <torch/csrc/jit/graph_executor.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/pickler.h>
#include <torch/csrc/jit/profiling_record.h>
#include <torch/csrc/jit/script/compilation_unit.h>
#include <torch/csrc/jit/script/error_report.h>
#include <torch/csrc/jit/script/jit_exception.h>
#include <torch/csrc/jit/script/logging.h>

#include <ATen/ExpandUtils.h>
#include <ATen/Parallel.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/core/Dict.h>
#include <ATen/core/ivalue.h>
#include <c10/core/thread_pool.h>
#include <c10/util/SmallVector.h>

#include <cctype>
#include <algorithm>
#include <cmath>
#include <exception>
#include <fstream>
#include <iostream>
#include <limits>
#include <memory>
#include <mutex>
#include <ostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <typeinfo>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

namespace torch {
|
|
namespace jit {
|
|
|
|
namespace {
|
|
|
|
Operation noop(const Node* n) {
|
|
return [](Stack& stack) { return 0; };
|
|
}
|
|
|
|
// using the rules from python_arg_parser FunctionParameter::check
|
|
// tensor cannot have grad set, tensor must be 0 dim,
|
|
// and if the dest is an int the source must be an integral type
|
|
void checkImplicitTensorToNum(at::Tensor t, bool toInt) {
|
|
if (autograd::as_variable_ref(t).requires_grad()) {
|
|
throw std::runtime_error(
|
|
"Cannot input a tensor that requires grad as a scalar argument");
|
|
}
|
|
if (t.sizes().size() != 0) {
|
|
throw std::runtime_error(
|
|
"Cannot input a tensor of dimension other than 0 as a scalar argument");
|
|
}
|
|
if (toInt &&
|
|
!isIntegralType(autograd::as_variable_ref(t).data().scalar_type())) {
|
|
std::stringstream ss;
|
|
ss << "Cannot input a tensor of type " << t.scalar_type()
|
|
<< " as an integral argument";
|
|
throw std::runtime_error(ss.str());
|
|
}
|
|
}
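// For example, a 0-dim long tensor that does not require grad is accepted,
// while a tensor that requires grad, a tensor with one or more dimensions,
// or (when toInt is true) a floating-point tensor raises a runtime_error.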
|
|
|
|
template <typename dtype> // int64_t, bool, double
|
|
Operation listConstruct(int64_t num_inputs) {
|
|
return [=](Stack& stack) {
|
|
auto inputs = peekSlice(stack, 0, num_inputs, num_inputs);
|
|
std::vector<dtype> vals =
|
|
fmap(inputs, [](const IValue& v) { return v.to<dtype>(); });
|
|
drop(stack, num_inputs);
|
|
push(stack, std::move(vals));
|
|
return 0;
|
|
};
|
|
}
|
|
|
|
static int64_t floordiv(int64_t a, int64_t b) {
|
|
if (b == 0) {
|
|
throw std::runtime_error("division by 0");
|
|
}
|
|
if ((a > 0) == (b > 0)) {
|
|
// simple case, both have same sign
|
|
return a / b;
|
|
} else {
|
|
// in Python, division rounds down; it does not truncate toward zero like in C++
|
|
auto r = lldiv(a, b);
|
|
return (r.rem) ? r.quot - 1 : r.quot;
|
|
}
|
|
}
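// e.g. floordiv(7, 2) == 3 and floordiv(-7, 2) == -4 (rounding toward
// negative infinity, as Python does), whereas C++ integer division would
// truncate -7 / 2 to -3.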
|
|
|
|
static int gcd(int a, int b) {
|
|
while (b != 0) {
|
|
int r = a % b;
|
|
a = b;
|
|
b = r;
|
|
}
|
|
// in python gcd returns non-negative values
|
|
return std::abs(a);
|
|
}
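// e.g. gcd(12, -18) == 6 and gcd(-12, -18) == 6; the std::abs keeps the
// result non-negative even when the inputs are negative.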
|
|
|
|
// reference function THPVariable_to in python_variable_methods.cpp
|
|
static at::Tensor to_dispatch(
|
|
at::Tensor self,
|
|
c10::optional<at::Device> device,
|
|
c10::optional<at::ScalarType> scalarType,
|
|
bool non_blocking,
|
|
bool copy) {
|
|
if (device && device->is_cuda()) {
|
|
at::globalContext().lazyInitCUDA();
|
|
}
|
|
if (!device && !scalarType && !copy) {
|
|
return self;
|
|
} else if (!device) {
|
|
return self.to(*scalarType, non_blocking, copy);
|
|
} else if (!scalarType) {
|
|
return self.to(*device, non_blocking, copy);
|
|
} else {
|
|
return self.to(*device, *scalarType, non_blocking, copy);
|
|
}
|
|
}
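// e.g. to_dispatch(t, c10::nullopt, at::kLong, false, false) behaves like
// t.to(at::kLong), and a no-op request (no device, no dtype, copy=false)
// returns self unchanged.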
|
|
|
|
// Convert a Python index (which may be negative) into an index usable for a
|
|
// C++ container
|
|
int64_t normalizeIndex(int64_t idx, int64_t list_size) {
|
|
if (idx < 0) {
|
|
// Handle negative indexing
|
|
idx = list_size + idx;
|
|
}
|
|
return idx;
|
|
}
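// e.g. normalizeIndex(-1, 5) == 4 and normalizeIndex(2, 5) == 2; the result
// is not bounds-checked here, so callers still validate the range.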
|
|
|
|
RegisterOperators reg(
|
|
{Operator(
|
|
prim::profile,
|
|
[](const Node* node) {
|
|
auto callback = node->cast<ProfileOp>()->getCallback();
|
|
return [callback](Stack& stack) {
|
|
callback(stack);
|
|
return 0;
|
|
};
|
|
}),
|
|
Operator(
|
|
prim::FusionGroup,
|
|
[](const Node* node) {
|
|
const auto key = registerFusion(node);
|
|
return [key](Stack& stack) {
|
|
RECORD_FUNCTION("FusionGroup", std::vector<c10::IValue>());
|
|
runFusion(key, stack);
|
|
return 0;
|
|
};
|
|
}),
|
|
Operator(
|
|
"prim::Guard(Tensor(a) t) -> Tensor(a)",
|
|
[](const Node* node) {
|
|
return [](Stack& stack) {
|
|
AT_ERROR("Should be replaced by prim::BailOut");
|
|
return 0;
|
|
};
|
|
}),
|
|
Operator(
|
|
"prim::rangelist(int n) -> int[]",
|
|
[](Stack& stack) {
|
|
int64_t n;
|
|
pop(stack, n);
|
|
std::vector<int64_t> elems(n);
|
|
for (int i = 0; i < n; i++) {
|
|
elems[i] = i;
|
|
}
|
|
push(stack, jit::IntList::create(elems));
|
|
return 0;
|
|
}),
|
|
Operator(
|
|
"prim::Bool(Tensor a) -> bool",
|
|
[](Stack& stack) {
|
|
at::Tensor a;
|
|
pop(stack, a);
|
|
push(stack, a.is_nonzero());
|
|
return 0;
|
|
}),
|
|
Operator(
|
|
"prim::Bool(int a) -> bool",
|
|
[](Stack& stack) {
|
|
int64_t i;
|
|
pop(stack, i);
|
|
push(stack, (bool)i);
|
|
return 0;
|
|
}),
|
|
Operator(
|
|
"prim::Bool(float a) -> bool",
|
|
[](Stack& stack) {
|
|
double d;
|
|
pop(stack, d);
|
|
push(stack, (bool)d);
|
|
return 0;
|
|
}),
|
|
Operator(
|
|
"prim::Int(Tensor a) -> int",
|
|
[](Stack& stack) {
|
|
at::Tensor a;
|
|
pop(stack, a);
|
|
push(stack, a.item<int64_t>());
|
|
return 0;
|
|
}),
|
|
Operator(
|
|
"prim::Float(Tensor a) -> float",
|
|
[](Stack& stack) {
|
|
at::Tensor a;
|
|
pop(stack, a);
|
|
push(stack, a.item<double>());
|
|
return 0;
|
|
}),
|
|
Operator(
|
|
"prim::ImplicitTensorToNum(Tensor a) -> Scalar",
|
|
[](const Node* node) -> Operation {
|
|
if (node->output()->type() == IntType::get()) {
|
|
return [](Stack& stack) {
|
|
at::Tensor a;
|
|
pop(stack, a);
|
|
checkImplicitTensorToNum(a, /*to int*/ true);
|
|
push(stack, a.item<int64_t>());
|
|
return 0;
|
|
};
|
|
} else {
|
|
return [](Stack& stack) {
|
|
at::Tensor a;
|
|
pop(stack, a);
|
|
checkImplicitTensorToNum(a, /*to int*/ false);
|
|
push(stack, a.item<double>());
|
|
return 0;
|
|
};
|
|
}
|
|
}),
|
|
Operator(
|
|
"prim::NumToTensor(Scalar a) -> Tensor",
|
|
[](Stack& stack) {
|
|
at::Scalar s;
|
|
pop(stack, s);
|
|
push(stack, autograd::make_variable(at::scalar_to_tensor(s)));
|
|
return 0;
|
|
}),
|
|
// note: this op needs to share a name with the Scalar -> Tensor conversion
|
|
// because all _to_tensor conversions have to have the same operator name
|
|
Operator(
|
|
"prim::NumToTensor(bool a) -> Tensor",
|
|
[](Stack& stack) {
|
|
bool b;
|
|
pop(stack, b);
|
|
push(stack, autograd::make_variable(at::scalar_to_tensor(b)));
|
|
return 0;
|
|
}),
|
|
Operator(
|
|
"prim::Float(Scalar a) -> float",
|
|
[](Stack& stack) {
|
|
IValue scalar;
|
|
pop(stack, scalar);
|
|
if (scalar.isDouble()) {
|
|
push(stack, scalar);
|
|
} else {
|
|
push(stack, static_cast<double>(scalar.toInt()));
|
|
}
|
|
return 0;
|
|
}),
|
|
Operator(
|
|
"prim::Float(int a) -> float",
|
|
[](Stack& stack) {
|
|
int64_t i;
|
|
pop(stack, i);
|
|
push(stack, (double)i);
|
|
return 0;
|
|
}),
|
|
Operator(
|
|
"prim::Int(float a) -> int",
|
|
[](Stack& stack) {
|
|
double d;
|
|
pop(stack, d);
|
|
push(stack, (int64_t)d);
|
|
return 0;
|
|
}),
|
|
Operator(
|
|
"prim::Float(bool a) -> float",
|
|
[](Stack& stack) {
|
|
bool b;
|
|
pop(stack, b);
|
|
push(stack, (float)b);
|
|
return 0;
|
|
}),
|
|
Operator(
|
|
"prim::Int(bool a) -> int",
|
|
[](Stack& stack) {
|
|
bool b;
|
|
pop(stack, b);
|
|
push(stack, (int)b);
|
|
return 0;
|
|
}),
|
|
Operator(
|
|
"prim::Int(Scalar a) -> int",
|
|
[](Stack& stack) {
|
|
IValue scalar;
|
|
pop(stack, scalar);
|
|
if (scalar.isInt()) {
|
|
push(stack, scalar);
|
|
} else {
|
|
push(stack, static_cast<int64_t>(scalar.toDouble()));
|
|
}
|
|
return 0;
|
|
}),
|
|
Operator(
|
|
"prim::Float(str a) -> float",
|
|
[](Stack& stack) {
|
|
auto s = pop(stack).toString();
|
|
if (s->string() == "inf")
|
|
push(stack, std::numeric_limits<double>::infinity());
|
|
else if (s->string() == "-inf")
|
|
push(stack, -std::numeric_limits<double>::infinity());
|
|
else
|
|
AT_ERROR(
|
|
"Only 'inf' or '-inf' can be cast to a float, but got '",
|
|
s->string(),
|
|
"'");
|
|
return 0;
|
|
}),
|
|
Operator(
|
|
"aten::device(str a) -> Device",
|
|
[](Stack& stack) {
|
|
push(stack, c10::Device(pop(stack).toStringRef()));
|
|
return 0;
|
|
}),
|
|
// reference function parse_to_conversion in python_arg_parsing.h
|
|
Operator(
|
|
"aten::to(Tensor(a) self, Device? device, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a|b)",
|
|
[](Stack& stack) {
|
|
bool non_blocking;
|
|
bool copy;
|
|
pop(stack, non_blocking, copy);
|
|
c10::optional<at::ScalarType> scalarType =
|
|
pop(stack).toOptional<at::ScalarType>();
|
|
c10::optional<c10::Device> device =
|
|
pop(stack).toOptional<c10::Device>();
|
|
at::Tensor self = pop(stack).toTensor();
|
|
push(
|
|
stack,
|
|
to_dispatch(self, device, scalarType, non_blocking, copy));
|
|
return 0;
|
|
}),
|
|
Operator(
|
|
"aten::to(Tensor(a) self, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a|b)",
|
|
[](Stack& stack) {
|
|
bool non_blocking;
|
|
bool copy;
|
|
pop(stack, non_blocking, copy);
|
|
c10::optional<at::ScalarType> scalarType =
|
|
pop(stack).toOptional<at::ScalarType>();
|
|
c10::optional<c10::Device> device = c10::nullopt;
|
|
at::Tensor self = pop(stack).toTensor();
|
|
push(
|
|
stack,
|
|
to_dispatch(self, device, scalarType, non_blocking, copy));
|
|
return 0;
|
|
}),
|
|
Operator(
|
|
"aten::to(Tensor(a) self, bool non_blocking=False, bool copy=False) -> Tensor(a|b)",
|
|
[](Stack& stack) {
|
|
at::Tensor self;
|
|
bool non_blocking;
|
|
bool copy;
|
|
pop(stack, self, non_blocking, copy);
|
|
c10::optional<c10::Device> device = c10::nullopt;
|
|
c10::optional<at::ScalarType> scalarType = c10::nullopt;
|
|
push(
|
|
stack,
|
|
to_dispatch(self, device, scalarType, non_blocking, copy));
|
|
return 0;
|
|
}),
|
|
Operator(
|
|
"aten::eq(Device a, Device b) -> bool",
|
|
[](Stack& stack) {
|
|
auto a = pop(stack).toDevice();
|
|
auto b = pop(stack).toDevice();
|
|
push(stack, a == b);
|
|
return 0;
|
|
}),
|
|
Operator(
|
|
"prim::device(Tensor a) -> Device",
|
|
[](Stack& stack) {
|
|
push(stack, pop(stack).toTensor().device());
|
|
return 0;
|
|
}),
|
|
Operator(
|
|
"prim::dtype(Tensor a) -> int",
|
|
[](Stack& stack) {
|
|
at::Tensor a;
|
|
pop(stack, a);
|
|
push(stack, static_cast<int64_t>(a.scalar_type()));
|
|
return 0;
|
|
}),
|
|
Operator(
|
|
"prim::requires_grad(Tensor a) -> bool",
|
|
[](Stack& stack) {
|
|
at::Tensor a;
|
|
pop(stack, a);
|
|
push(stack, a.requires_grad());
|
|
return 0;
|
|
}),
|
|
Operator(
|
|
"prim::shape(Tensor a) -> int[]",
|
|
[](Stack& stack) {
|
|
at::Tensor a;
|
|
pop(stack, a);
|
|
push(stack, a.sizes());
|
|
return 0;
|
|
}),
|
|
Operator(
|
|
"prim::is_cuda(Tensor a) -> bool",
|
|
[](Stack& stack) {
|
|
at::Tensor a;
|
|
pop(stack, a);
|
|
push(stack, a.is_cuda());
|
|
return 0;
|
|
}),
|
|
Operator(
|
|
"aten::cpu(Tensor(a) self) -> Tensor(a|b)",
|
|
[](Stack& stack) {
|
|
at::Tensor a;
|
|
pop(stack, a);
|
|
push(stack, a.cpu());
|
|
return 0;
|
|
}),
|
|
Operator(
|
|
// TODO return generator object when torchscript supports RNG
|
|
// first-class
|
|
"aten::manual_seed(int seed) -> ()",
|
|
[](Stack& stack) {
|
|
at::manual_seed(pop(stack).toInt());
|
|
return 0;
|
|
}),
|
|
Operator(
|
|
"aten::cuda(Tensor(a) self) -> Tensor(a|b)",
|
|
[](Stack& stack) {
|
|
at::Tensor a;
|
|
pop(stack, a);
|
|
push(stack, a.cuda());
|
|
return 0;
|
|
}),
|
|
Operator(
|
|
"prim::AutogradZero() -> Tensor",
|
|
[](const Node* node) {
|
|
return [](Stack& stack) {
|
|
stack.emplace_back(at::Tensor());
|
|
return 0;
|
|
};
|
|
}),
|
|
Operator(
|
|
"aten::save(t item, str filename) -> ()",
|
|
[](Stack& stack) {
|
|
auto filename = pop(stack).toStringRef();
|
|
auto value = pop(stack);
|
|
|
|
// Pickle the tensor
|
|
Pickler p;
|
|
p.pushMetadata();
|
|
p.start();
|
|
p.addIValue(value);
|
|
p.finish();
|
|
|
|
// Write file
|
|
std::fstream output(filename, std::ios::out | std::ios::binary);
|
|
output.write(p.stack().data(), p.stack().size());
|
|
return 0;
|
|
}),
|
|
Operator(
|
|
prim::Print,
|
|
[](const Node* node) {
|
|
size_t num_inputs = node->inputs().size();
|
|
return [num_inputs](Stack& stack) {
|
|
bool first = true;
|
|
for (const IValue& i : last(stack, num_inputs)) {
|
|
if (!first)
|
|
std::cout << " ";
|
|
first = false;
|
|
std::cout << i;
|
|
}
|
|
drop(stack, num_inputs);
|
|
std::cout << std::endl;
|
|
return 0;
|
|
};
|
|
}),
|
|
Operator(
|
|
prim::BroadcastSizes,
|
|
[](const Node* node) -> Operation {
|
|
size_t num_inputs = node->inputs().size();
|
|
return [num_inputs](Stack& stack) {
|
|
std::vector<int64_t> size;
|
|
size.reserve(8);
|
|
for (size_t i = 0; i < num_inputs; ++i) {
|
|
size = at::infer_size(
|
|
size, peek(stack, i, num_inputs).toIntList()->elements());
|
|
}
|
|
drop(stack, num_inputs);
|
|
push(stack, std::move(size));
|
|
return 0;
|
|
};
|
|
}),
|
|
Operator(
|
|
prim::ChunkSizes,
|
|
[](const Node* node) -> Operation {
|
|
int64_t raw_dim = node->i(attr::dim);
|
|
int64_t chunks = node->i(attr::chunks);
|
|
return [raw_dim, chunks](Stack& stack) {
|
|
Shared<IntList> sizes_l;
|
|
pop(stack, sizes_l);
|
|
const auto& shape = sizes_l->elements();
|
|
std::vector<int64_t> regular_shape = shape;
|
|
std::vector<int64_t> last_shape = shape;
|
|
int64_t dim = at::maybe_wrap_dim(raw_dim, shape.size());
|
|
TORCH_CHECK(
|
|
dim < (int64_t)regular_shape.size(),
|
|
"Dimension out of range for chunk");
|
|
int64_t split_size = (regular_shape[dim] + chunks - 1) / chunks;
|
|
regular_shape[dim] = split_size;
|
|
if (shape[dim] % chunks == 0) {
|
|
last_shape[dim] = split_size;
|
|
} else {
|
|
int64_t num_splits = std::max<int64_t>(
|
|
(shape[dim] + split_size - 1) / split_size, 1);
|
|
last_shape[dim] =
|
|
split_size - (split_size * num_splits - shape[dim]);
|
|
AT_ASSERT(last_shape[dim] >= 0);
|
|
}
|
|
push(stack, std::move(regular_shape));
|
|
push(stack, std::move(last_shape));
|
|
return 0;
|
|
};
|
|
}),
|
|
Operator(
|
|
FunctionSchema(
|
|
"aten::warn",
|
|
"",
|
|
{Argument("message", StringType::get()),
|
|
Argument("stacklevel", IntType::get(), c10::nullopt, 2, true)},
|
|
{}),
|
|
[](const Node* node) {
|
|
return [](Stack& stack) {
|
|
drop(stack, 1);
|
|
AT_WARN(pop(stack).toStringRef());
|
|
return 0;
|
|
};
|
|
}),
|
|
Operator(
|
|
"prim::RaiseException(str msg) -> ()",
|
|
[](Stack& stack) {
|
|
throw JITException(pop(stack).toStringRef());
|
|
return 0;
|
|
}),
|
|
|
|
Operator(
|
|
"prim::IgnoredPythonOp(...) -> None",
|
|
[](Stack& stack) {
|
|
throw JITException(
|
|
"This Python function is annotated to be ignored"
|
|
" and cannot be and has not been included in the exported"
|
|
" binary, meaning that it cannot be executed now."
|
|
" Make sure that ignored operations are never executed after"
|
|
" import");
|
|
return 0;
|
|
}),
|
|
|
|
// Load x, y
|
|
// loads values from registers onto the stack; the actual callback does
|
|
// nothing since the stack manipulation is already encoded in inst.inputs
|
|
// and inst.outputs
|
|
Operator(prim::Load, noop),
|
|
// x, y = Store
|
|
// stores values from the stack into registers; the actual callback does
|
|
// nothing since the stack manipulation is already encoded in inst.inputs
|
|
// and inst.outputs
|
|
Operator(prim::Store, noop),
|
|
Operator(
|
|
prim::Drop,
|
|
[](const Node* node) {
|
|
auto N = node->inputs().size();
|
|
return [=](Stack& stack) {
|
|
drop(stack, N);
|
|
return 0;
|
|
};
|
|
}),
|
|
Operator(
|
|
c10::onnx::Reshape,
|
|
[](const Node* node) {
|
|
return [=](Stack& stack) {
|
|
at::Tensor input, shape;
|
|
pop(stack, input, shape);
|
|
shape = shape.contiguous();
|
|
AT_ASSERT(shape.ndimension() == 1);
|
|
at::IntArrayRef shape_list(shape.data<int64_t>(), shape.size(0));
|
|
push(stack, input.reshape(shape_list));
|
|
return 0;
|
|
};
|
|
}),
|
|
Operator(
|
|
c10::onnx::Shape,
|
|
[](const Node* node) {
|
|
return [=](Stack& stack) {
|
|
auto t = pop(stack).toTensor();
|
|
at::IntArrayRef sizes = t.sizes();
|
|
auto sizes_tensor = torch::empty(
|
|
{static_cast<int64_t>(sizes.size())}, at::dtype(at::kLong));
|
|
auto accessor = sizes_tensor.accessor<int64_t, 1>();
|
|
for (size_t i = 0; i < sizes.size(); ++i) {
|
|
accessor[i] = sizes[i];
|
|
}
|
|
stack.emplace_back(sizes_tensor);
|
|
return 0;
|
|
};
|
|
}),
|
|
Operator(
|
|
prim::AutogradAnyNonZero,
|
|
[](const Node* node) {
|
|
size_t num_inputs = node->inputs().size();
|
|
return [=](Stack& stack) {
|
|
bool result = false;
|
|
for (const IValue& t : last(stack, num_inputs)) {
|
|
if (t.toTensor().defined()) {
|
|
result = true;
|
|
break;
|
|
}
|
|
}
|
|
drop(stack, num_inputs);
|
|
stack.emplace_back(result);
|
|
return 0;
|
|
};
|
|
}),
|
|
Operator(
|
|
prim::AutogradAdd,
|
|
[](const Node* node) {
|
|
return [=](Stack& stack) {
|
|
at::Tensor a, b;
|
|
pop(stack, a, b);
|
|
if (!a.defined())
|
|
stack.emplace_back(b);
|
|
else if (!b.defined())
|
|
stack.emplace_back(a);
|
|
else
|
|
stack.emplace_back(a + b);
|
|
return 0;
|
|
};
|
|
}),
|
|
Operator(
|
|
"aten::_grad_sum_to_size(Tensor(a) self, int[] size) -> Tensor(a)",
|
|
[](Stack& stack) {
|
|
at::Tensor self;
|
|
Shared<IntList> desired_sizes;
|
|
pop(stack, self, desired_sizes);
|
|
push(stack, at::sum_to(std::move(self), desired_sizes->elements()));
|
|
return 0;
|
|
}),
|
|
Operator(
|
|
prim::TupleUnpack,
|
|
[](const Node* node) {
|
|
size_t num_elems = node->outputs().size();
|
|
return [=](Stack& stack) {
|
|
auto t = pop(stack).toTuple();
|
|
const auto& elems = t->elements();
|
|
if (elems.size() != num_elems) {
|
|
AT_ERROR(
|
|
"Expected a tuple of ",
|
|
num_elems,
|
|
" elements, but got ",
|
|
elems.size());
|
|
}
|
|
stack.insert(stack.end(), elems.begin(), elems.end());
|
|
return 0;
|
|
};
|
|
}),
|
|
Operator(
|
|
prim::TupleSlice,
|
|
[](const Node* node) {
|
|
int64_t beg_ind = node->i(attr::beg);
|
|
int64_t end_ind = node->i(attr::end);
|
|
return [=](Stack& stack) {
|
|
auto t = pop(stack).toTuple();
|
|
const auto& elems = t->elements();
|
|
std::vector<IValue> output_elems;
|
|
for (int64_t i = beg_ind; i < end_ind; ++i) {
|
|
output_elems.emplace_back(elems.at(i));
|
|
}
|
|
push(stack, Tuple::create(std::move(output_elems)));
|
|
return 0;
|
|
};
|
|
}),
|
|
Operator(
|
|
prim::TupleIndex,
|
|
[](const Node* node) {
|
|
return [](Stack& stack) {
|
|
int64_t index = pop(stack).toInt();
|
|
auto tup = pop(stack).toTuple();
|
|
const auto& elems = tup->elements();
|
|
auto norm_index = normalizeIndex(index, elems.size());
|
|
if (norm_index < 0 ||
|
|
norm_index >= static_cast<int64_t>(elems.size())) {
|
|
throw std::out_of_range("Tuple list index out of range");
|
|
}
|
|
stack.emplace_back(elems.at(norm_index));
|
|
return 0;
|
|
};
|
|
}),
|
|
Operator(
|
|
prim::TupleConstruct,
|
|
[](const Node* node) {
|
|
size_t num_inputs = node->inputs().size();
|
|
return [=](Stack& stack) {
|
|
std::vector<IValue> elems{
|
|
std::make_move_iterator(stack.end() - num_inputs),
|
|
std::make_move_iterator(stack.end())};
|
|
drop(stack, num_inputs);
|
|
push(stack, Tuple::create(std::move(elems)));
|
|
return 0;
|
|
};
|
|
}),
|
|
Operator(
|
|
prim::ConstantChunk,
|
|
[](const Node* node) {
|
|
int64_t chunks = node->i(attr::chunks);
|
|
int64_t dim = node->i(attr::dim);
|
|
auto outputs_used = fmap(node->outputs(), [](const Value* v) {
|
|
return v->uses().size() > 0;
|
|
});
|
|
return [=](Stack& stack) {
|
|
RECORD_FUNCTION("chunk", last(stack, 1));
|
|
|
|
at::Tensor t;
|
|
pop(stack, t);
|
|
auto result = at::chunk(t, chunks, dim);
|
|
stack.insert(
|
|
stack.end(),
|
|
std::make_move_iterator(result.begin()),
|
|
std::make_move_iterator(result.end()));
|
|
// NB: Chunk can sometimes return a smaller number of outputs.
|
|
int64_t num_results = result.size();
|
|
if (num_results != chunks) {
|
|
if (num_results > chunks) {
|
|
TORCH_CHECK(
|
|
num_results == chunks,
|
|
"Expected chunk to return ",
|
|
chunks,
|
|
" outputs, but got ",
|
|
num_results);
|
|
}
|
|
for (int64_t i = num_results; i < chunks; ++i) {
|
|
TORCH_CHECK(
|
|
!outputs_used[i],
|
|
"Expected chunk to return at least ",
|
|
chunks,
|
|
" outputs, but got only ",
|
|
num_results);
|
|
// We know that the output is unused, so it's ok to push
|
|
// anything on the stack.
|
|
stack.emplace_back();
|
|
}
|
|
}
|
|
return 0;
|
|
};
|
|
}),
|
|
Operator(
|
|
prim::ListUnpack,
|
|
[](const Node* node) -> Operation {
|
|
const auto num_outputs = node->outputs().size();
|
|
ListTypePtr lt = node->input()->type()->expect<ListType>();
|
|
if (lt->getElementType() == IntType::get()) {
|
|
return [=](Stack& stack) {
|
|
auto ilist = pop(stack);
|
|
const auto& list = ilist.toIntList()->elements();
|
|
TORCH_CHECK(
|
|
list.size() == num_outputs,
|
|
"Expected ",
|
|
num_outputs,
|
|
" elements in a list but found ",
|
|
list.size());
|
|
stack.insert(stack.end(), list.begin(), list.end());
|
|
return 0;
|
|
};
|
|
} else if (lt->getElementType() == FloatType::get()) {
|
|
return [=](Stack& stack) {
|
|
auto ilist = pop(stack);
|
|
const auto& list = ilist.toDoubleList()->elements();
|
|
TORCH_CHECK(
|
|
list.size() == num_outputs,
|
|
"Expected ",
|
|
num_outputs,
|
|
" elements in a list but found ",
|
|
list.size());
|
|
stack.insert(stack.end(), list.begin(), list.end());
|
|
return 0;
|
|
};
|
|
} else if (lt->getElementType() == TensorType::get()) {
|
|
return [=](Stack& stack) {
|
|
auto ilist = pop(stack);
|
|
const auto& list = ilist.toTensorList()->elements();
|
|
TORCH_CHECK(
|
|
list.size() == num_outputs,
|
|
"Expected ",
|
|
num_outputs,
|
|
" elements in a list but found ",
|
|
list.size());
|
|
stack.insert(stack.end(), list.begin(), list.end());
|
|
return 0;
|
|
};
|
|
} else {
|
|
return [=](Stack& stack) {
|
|
auto glist = pop(stack);
|
|
const auto& list = glist.toGenericList()->elements();
|
|
TORCH_CHECK(
|
|
list.size() == num_outputs,
|
|
"Expected ",
|
|
num_outputs,
|
|
" elements in a list but found ",
|
|
list.size());
|
|
stack.insert(stack.end(), list.begin(), list.end());
|
|
return 0;
|
|
};
|
|
}
|
|
}),
|
|
Operator(
|
|
prim::ListConstruct,
|
|
[](const Node* node) -> Operation {
|
|
const auto num_inputs = node->inputs().size();
|
|
ListTypePtr lt = node->output()->type()->expect<ListType>();
|
|
if (IntType::get() == lt->getElementType()) {
|
|
return listConstruct<int64_t>(num_inputs);
|
|
} else if (FloatType::get() == lt->getElementType()) {
|
|
return listConstruct<double>(num_inputs);
|
|
} else if (lt->getElementType() == BoolType::get()) {
|
|
return listConstruct<bool>(num_inputs);
|
|
} else if (lt->getElementType()->isSubtypeOf(TensorType::get())) {
|
|
return [=](Stack& stack) {
|
|
const size_t stack_size = stack.size();
|
|
std::vector<at::Tensor> vals;
|
|
vals.reserve(num_inputs);
|
|
for (size_t i = stack_size - num_inputs; i < stack_size; ++i) {
|
|
vals.emplace_back(std::move(stack[i]).toTensor());
|
|
}
|
|
drop(stack, num_inputs);
|
|
push(stack, std::move(vals));
|
|
return 0;
|
|
};
|
|
} else {
|
|
return [=](Stack& stack) {
|
|
const size_t stack_size = stack.size();
|
|
std::vector<IValue> vals;
|
|
vals.reserve(num_inputs);
|
|
for (size_t i = stack_size - num_inputs; i < stack_size; ++i) {
|
|
vals.emplace_back(std::move(stack[i]));
|
|
}
|
|
drop(stack, num_inputs);
|
|
push(stack, std::move(vals));
|
|
return 0;
|
|
};
|
|
}
|
|
}),
|
|
Operator(
|
|
prim::DictConstruct,
|
|
[](const Node* node) -> Operation {
|
|
const auto num_inputs = node->inputs().size();
|
|
if (num_inputs % 2 != 0) {
|
|
throw std::runtime_error(
|
|
"DictConstruct must have an even number of inputs");
|
|
}
|
|
return [=](Stack& stack) {
|
|
c10::impl::GenericDict vals;
|
|
for (size_t i = 0; i < num_inputs; i += 2) {
|
|
auto val = pop(stack);
|
|
auto key = pop(stack);
|
|
vals.insert_or_assign(std::move(key), std::move(val));
|
|
}
|
|
push(stack, std::move(vals));
|
|
return 0;
|
|
};
|
|
}),
|
|
Operator(
|
|
"aten::_unwrap_optional(t(a)? optional) -> t(a)",
|
|
[](Stack& stack) {
|
|
auto val = pop(stack);
|
|
TORCH_CHECK(!val.isNone(), "Unwrapping null optional");
|
|
push(stack, val);
|
|
return 0;
|
|
}),
|
|
// This op can be removed in preprocessing before being run in the
|
|
// interpreter (though it currently is not); even when it is removed, it
|
|
// needs to remain a registered op so that constant prop can run.
|
|
Operator("prim::unchecked_unwrap_optional(t(a)? optional) -> t(a)", noop),
|
|
Operator(
|
|
prim::fork,
|
|
[](const Node* node) {
|
|
Code code(node->g(attr::Subgraph));
|
|
int n_inputs = node->inputs().size();
|
|
AT_ASSERT(node->blocks().size() == 0);
|
|
AT_ASSERT(node->hasAttribute(attr::Subgraph));
|
|
return [=](Stack& stack) {
|
|
// Move inputs to a separate stack
|
|
InterpreterState forked_interpreter(code);
|
|
InterpreterContinuation continuation(
|
|
forked_interpreter,
|
|
Stack(stack.end() - n_inputs, stack.end()),
|
|
autograd::GradMode::is_enabled());
|
|
drop(stack, n_inputs);
|
|
|
|
push(stack, forked_interpreter.getFuture());
|
|
|
|
at::launch(std::move(continuation));
|
|
return 0;
|
|
};
|
|
}),
|
|
Operator(
|
|
"aten::wait(Future(t) self) -> t",
|
|
[](Stack& stack) {
|
|
auto future = pop(stack).toFuture();
|
|
if (future->completed()) {
|
|
push(stack, future->value());
|
|
} else {
|
|
throw Suspend(future);
|
|
}
|
|
return 0;
|
|
}),
|
|
Operator(
|
|
prim::CreateObject,
|
|
[](const Node* node) {
|
|
const auto type = node->output()->type()->expect<ClassType>();
|
|
const size_t numAttrs = type->numAttributes();
|
|
return [type, numAttrs](Stack& stack) {
|
|
auto userObj = c10::ivalue::Object::create(type, numAttrs);
|
|
push(stack, std::move(userObj));
|
|
return 0;
|
|
};
|
|
}),
|
|
Operator(
|
|
prim::GetAttr,
|
|
[](const Node* node) {
|
|
const auto type = node->input()->type()->expect<ClassType>();
|
|
const auto& field = node->s(attr::name);
|
|
const auto slot = type->getAttributeSlot(field);
|
|
return [slot](Stack& stack) {
|
|
auto userObj = pop(stack).toObject();
|
|
auto value = userObj->getSlot(slot);
|
|
push(stack, std::move(value));
|
|
return 0;
|
|
};
|
|
}),
|
|
Operator(prim::SetAttr, [](const Node* node) {
|
|
const auto type = node->inputs().at(0)->type()->expect<ClassType>();
|
|
const auto& field = node->s(attr::name);
|
|
const auto slot = type->getAttributeSlot(field);
|
|
return [slot](Stack& stack) {
|
|
auto v = pop(stack);
|
|
auto userObj = pop(stack).toObject();
|
|
userObj->setSlot(slot, std::move(v));
|
|
return 0;
|
|
};
|
|
})});
|
|
|
|
RegisterOperators logging_operators(
|
|
{Operator(
|
|
"prim::AddStatValue(str key, int val) -> ()",
|
|
[](Stack& stack) {
|
|
auto val = pop(stack).toInt();
|
|
auto key = pop(stack).toString();
|
|
|
|
auto schema =
|
|
parseSchema("prim::AddStatValue(str key, int val) -> ()");
|
|
// TODO: remove this custom tracing code once the custom op bugfix
|
|
// lands
|
|
if (jit::tracer::isTracing()) {
|
|
const auto& graph = tracer::getTracingState()->graph;
|
|
Node* node = graph->create(prim::AddStatValue, /*num_outputs=*/0);
|
|
tracer::recordSourceLocation(node);
|
|
node->addInput(insertConstant(*graph, key));
|
|
tracer::addInputs(node, "val", val);
|
|
graph->insertNode(node);
|
|
}
|
|
torch::jit::logging::getLogger()->addStatValue(*key, val);
|
|
return 0;
|
|
}),
|
|
Operator("prim::TimePoint() -> int", [](Stack& stack) {
|
|
auto schema = parseSchema("prim::TimePoint() -> int");
|
|
Node* node = nullptr;
|
|
// TODO: remove this custom tracing code once the custom op bugfix lands
|
|
if (jit::tracer::isTracing()) {
|
|
const auto& graph = tracer::getTracingState()->graph;
|
|
node = graph->create(prim::TimePoint, /*num_outputs=*/0);
|
|
tracer::recordSourceLocation(node);
|
|
graph->insertNode(node);
|
|
}
|
|
auto output = autograd::profiler::getTime();
|
|
push(stack, output);
|
|
if (jit::tracer::isTracing()) {
|
|
jit::tracer::addOutput(node, output);
|
|
}
|
|
return 0;
|
|
})});
|
|
|
|
// define implementations for primitive number ops
|
|
#define DEFINE_GENERIC_OP(aten_op, int_op, float_op, int_result, float_result) \
|
|
Operator( \
|
|
#aten_op "(int a, int b) -> " #int_result, \
|
|
[](Stack& stack) { \
|
|
int64_t a, b; \
|
|
pop(stack, a, b); \
|
|
push(stack, int_op); \
|
|
return 0; \
|
|
}), \
|
|
Operator( \
|
|
#aten_op "(float a, float b) -> " #float_result, [](Stack& stack) { \
|
|
double a, b; \
|
|
pop(stack, a, b); \
|
|
push(stack, float_op); \
|
|
return 0; \
|
|
})
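// For example, DEFINE_GENERIC_OP(aten::add, a + b, a + b, int, float) expands
// to two operators, "aten::add(int a, int b) -> int" and
// "aten::add(float a, float b) -> float", each popping its two operands and
// pushing the result.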
|
|
|
|
#define DEFINE_INT_FLOAT_OP(aten_op, op, result) \
|
|
Operator( \
|
|
#aten_op "(int a, float b) -> " #result, \
|
|
[](Stack& stack) { \
|
|
int64_t a; \
|
|
double b; \
|
|
pop(stack, a, b); \
|
|
push(stack, op); \
|
|
return 0; \
|
|
}), \
|
|
Operator(#aten_op "(float a, int b) -> " #result, [](Stack& stack) { \
|
|
double a; \
|
|
int64_t b; \
|
|
pop(stack, a, b); \
|
|
push(stack, op); \
|
|
return 0; \
|
|
})
|
|
|
|
#define DEFINE_INT_OP(aten_op, op) \
|
|
Operator(#aten_op "(int a, int b) -> int", [](Stack& stack) { \
|
|
int64_t a, b; \
|
|
pop(stack, a, b); \
|
|
push(stack, op); /* NOLINT(hicpp-signed-bitwise) */ \
|
|
return 0; \
|
|
})
|
|
|
|
#define DEFINE_STR_CMP_OP(aten_op, op) \
|
|
Operator(#aten_op "(str a, str b) -> bool", [](Stack& stack) { \
|
|
auto b = pop(stack).toStringRef(); \
|
|
auto a = pop(stack).toStringRef(); \
|
|
push(stack, op); \
|
|
return 0; \
|
|
})
|
|
|
|
#define DEFINE_BINARY_OP(aten_op, op) \
|
|
DEFINE_GENERIC_OP(aten_op, op, op, int, float), \
|
|
DEFINE_INT_FLOAT_OP(aten_op, op, float)
|
|
#define DEFINE_COMPARISON_OP(aten_op, op) \
|
|
DEFINE_GENERIC_OP(aten_op, op, op, bool, bool), \
|
|
DEFINE_INT_FLOAT_OP(aten_op, op, bool), DEFINE_STR_CMP_OP(aten_op, op)
|
|
|
|
#define DEFINE_BOOL_OP(aten_op, op) \
|
|
Operator(#aten_op "(bool a, bool b) -> bool", [](Stack& stack) { \
|
|
bool a, b; \
|
|
pop(stack, a, b); \
|
|
push(stack, op); \
|
|
return 0; \
|
|
})
|
|
|
|
int stringSlice(Stack& stack) {
|
|
auto step = pop(stack).toInt();
|
|
TORCH_CHECK(step == 1, "Slicing a string only supports step=1");
|
|
|
|
auto end = pop(stack).toInt();
|
|
auto start = pop(stack).toInt();
|
|
auto string = pop(stack).toStringRef();
|
|
const int64_t size = string.size();
|
|
|
|
// Clamp start and end to the bounds of the string
|
|
start = std::max(int64_t(0), normalizeIndex(start, size));
|
|
end = std::min(size, normalizeIndex(end, size));
|
|
|
|
if (end <= start) {
|
|
// Slice is empty
|
|
push(stack, std::string(""));
|
|
return 0;
|
|
}
|
|
|
|
std::string result(string.begin() + start, string.begin() + end);
|
|
push(stack, result);
|
|
return 0;
|
|
}
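// e.g. slicing "hello" with start=1, end=3 yields "el", and with start=-3,
// end=5 yields "llo" (negative indices count from the end, as in Python).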
|
|
|
|
// Equivalent to list.at(idx)
|
|
template <typename TList> // something like Shared<IntList>
|
|
typename TList::element_type::ElemType& getItem(TList& list, int64_t idx) {
|
|
const int64_t list_size = list->elements().size();
|
|
const int64_t normalized_idx = normalizeIndex(idx, list_size);
|
|
if (normalized_idx < 0 || normalized_idx >= list_size) {
|
|
throw std::out_of_range("list index out of range");
|
|
}
|
|
return list->elements()[normalized_idx];
|
|
}
|
|
|
|
// cannot return a reference to an element in a bool vector
|
|
bool getBoolItem(const std::vector<bool>& list, int64_t idx) {
|
|
const int64_t list_size = list.size();
|
|
const int64_t normalized_idx = normalizeIndex(idx, list_size);
|
|
if (normalized_idx < 0 || normalized_idx >= list_size) {
|
|
throw std::out_of_range("list index out of range");
|
|
}
|
|
return list[normalized_idx];
|
|
}
|
|
|
|
template <typename TList, typename TElement>
|
|
int listAppend(Stack& stack) {
|
|
TList a;
|
|
TElement el;
|
|
pop(stack, a, el);
|
|
|
|
a->elements().push_back(el);
|
|
push(stack, a);
|
|
|
|
return 0;
|
|
}
|
|
|
|
template <typename TList>
|
|
int listReverse(Stack& stack) {
|
|
TList a;
|
|
pop(stack, a);
|
|
|
|
auto& elements = a->elements();
|
|
std::reverse(elements.begin(), elements.end());
|
|
|
|
return 0;
|
|
}
|
|
|
|
template <typename TList>
|
|
int listPop(Stack& stack) {
|
|
TList list;
|
|
int64_t idx;
|
|
pop(stack, list, idx);
|
|
|
|
auto& elements = list->elements();
|
|
const int64_t list_size = elements.size();
|
|
const int64_t normalized_idx = normalizeIndex(idx, list_size);
|
|
|
|
if (list_size == 0) {
|
|
AT_ERROR("pop from empty list");
|
|
}
|
|
|
|
push(stack, std::move(getItem(list, idx)));
|
|
elements.erase(elements.begin() + normalized_idx);
|
|
|
|
return 0;
|
|
}
|
|
|
|
template <>
|
|
int listPop<Shared<BoolList>>(Stack& stack) {
|
|
Shared<BoolList> list;
|
|
int64_t idx;
|
|
pop(stack, list, idx);
|
|
|
|
auto& elements = list->elements();
|
|
const int64_t list_size = elements.size();
|
|
const int64_t normalized_idx = normalizeIndex(idx, list_size);
|
|
|
|
if (list_size == 0) {
|
|
AT_ERROR("pop from empty list");
|
|
}
|
|
|
|
push(stack, getBoolItem(elements, idx));
|
|
elements.erase(elements.begin() + normalized_idx);
|
|
|
|
return 0;
|
|
}
|
|
|
|
template <typename TList>
|
|
int listClear(Stack& stack) {
|
|
TList a;
|
|
pop(stack, a);
|
|
|
|
a->elements().clear();
|
|
return 0;
|
|
}
|
|
|
|
template <typename TList, typename TElement>
|
|
int listInsert(Stack& stack) {
|
|
TList list;
|
|
int64_t idx;
|
|
TElement elem;
|
|
pop(stack, list, idx, elem);
|
|
|
|
auto& elements = list->elements();
|
|
const int64_t list_size = elements.size();
|
|
const int64_t normalized_idx = normalizeIndex(idx, list_size);
|
|
|
|
if (normalized_idx < 0 || normalized_idx >= list_size) {
|
|
if (normalized_idx < 0) {
|
|
elements.insert(elements.begin(), elem);
|
|
} else {
|
|
elements.push_back(elem);
|
|
}
|
|
} else {
|
|
elements.insert(elements.begin() + normalized_idx, elem);
|
|
}
|
|
|
|
return 0;
|
|
}
|
|
|
|
template <typename TList, typename TElement>
|
|
int listRemove(Stack& stack) {
|
|
TList list;
|
|
TElement elem;
|
|
pop(stack, list, elem);
|
|
|
|
auto& elements = list->elements();
|
|
auto pos = std::find(elements.begin(), elements.end(), elem);
|
|
|
|
if (pos != elements.end()) {
|
|
elements.erase(pos);
|
|
} else {
|
|
AT_ERROR("list.remove(x): x not in list");
|
|
}
|
|
|
|
return 0;
|
|
}
|
|
|
|
template <>
|
|
int listRemove<Shared<TensorList>, at::Tensor>(Stack& stack) {
|
|
Shared<TensorList> list;
|
|
at::Tensor elem;
|
|
pop(stack, list, elem);
|
|
|
|
auto& elements = list->elements();
|
|
auto pos = std::find_if(
|
|
elements.begin(), elements.end(), [elem](const at::Tensor& b) {
|
|
const auto cmp_result = elem.eq(b);
|
|
return cmp_result.is_nonzero();
|
|
});
|
|
|
|
if (pos != elements.end()) {
|
|
elements.erase(pos);
|
|
} else {
|
|
AT_ERROR("list.remove(x): x not in list");
|
|
}
|
|
|
|
return 0;
|
|
}
|
|
|
|
template <typename TList, typename TElement>
|
|
int listIndex(Stack& stack) {
|
|
TList list;
|
|
TElement elem;
|
|
pop(stack, list, elem);
|
|
|
|
auto& elements = list->elements();
|
|
auto pos = std::find(elements.begin(), elements.end(), elem);
|
|
|
|
if (pos != elements.end()) {
|
|
push(stack, static_cast<int64_t>(std::distance(elements.begin(), pos)));
|
|
} else {
|
|
AT_ERROR("'", elem, "' is not in list");
|
|
}
|
|
|
|
return 0;
|
|
}
|
|
|
|
template <>
|
|
int listIndex<Shared<TensorList>, at::Tensor>(Stack& stack) {
|
|
Shared<TensorList> list;
|
|
at::Tensor elem;
|
|
pop(stack, list, elem);
|
|
|
|
auto& elements = list->elements();
|
|
auto pos = std::find_if(
|
|
elements.begin(), elements.end(), [elem](const at::Tensor& b) {
|
|
const auto cmp_result = elem.eq(b);
|
|
return cmp_result.is_nonzero();
|
|
});
|
|
|
|
if (pos != elements.end()) {
|
|
push(stack, static_cast<int64_t>(std::distance(elements.begin(), pos)));
|
|
} else {
|
|
AT_ERROR("'", elem, "' is not in list");
|
|
}
|
|
|
|
return 0;
|
|
}
|
|
|
|
template <typename TList, typename TElement>
|
|
int listCount(Stack& stack) {
|
|
TList list;
|
|
TElement elem;
|
|
pop(stack, list, elem);
|
|
|
|
auto& elements = list->elements();
|
|
const int64_t count = std::count(elements.begin(), elements.end(), elem);
|
|
push(stack, count);
|
|
|
|
return 0;
|
|
}
|
|
|
|
template <>
|
|
int listCount<Shared<TensorList>, at::Tensor>(Stack& stack) {
|
|
Shared<TensorList> list;
|
|
at::Tensor elem;
|
|
pop(stack, list, elem);
|
|
|
|
auto& elements = list->elements();
|
|
const int64_t count = std::count_if(
|
|
elements.begin(), elements.end(), [elem](const at::Tensor& b) {
|
|
const auto cmp_result = elem.eq(b);
|
|
return cmp_result.is_nonzero();
|
|
});
|
|
push(stack, count);
|
|
|
|
return 0;
|
|
}
|
|
|
|
template <typename TList>
|
|
Operation listExtend(const Node* node) {
|
|
return [](Stack& stack) {
|
|
TList a;
|
|
TList b;
|
|
pop(stack, a, b);
|
|
|
|
auto& vec_a = a->elements();
|
|
const auto& vec_b = b->elements();
|
|
vec_a.insert(vec_a.end(), vec_b.cbegin(), vec_b.cend());
|
|
return 0;
|
|
};
|
|
}
|
|
|
|
template <typename TList>
|
|
Operation listCopy(const Node* node) {
|
|
return [](Stack& stack) {
|
|
TList list;
|
|
pop(stack, list);
|
|
|
|
const auto& vec = list->elements();
|
|
auto out = vec;
|
|
push(stack, out);
|
|
return 0;
|
|
};
|
|
}
|
|
|
|
template <typename T>
|
|
int listSelect(Stack& stack) {
|
|
T list;
|
|
int64_t idx;
|
|
pop(stack, list, idx);
|
|
|
|
auto element = getItem(list, idx);
|
|
push(stack, std::move(element));
|
|
return 0;
|
|
}
|
|
|
|
// needs a specialization because we cannot return a reference to an element of a bool vector
|
|
template <>
|
|
int listSelect<Shared<BoolList>>(Stack& stack) {
|
|
Shared<BoolList> list;
|
|
int64_t idx;
|
|
pop(stack, list, idx);
|
|
|
|
auto element = getBoolItem(list->elements(), idx);
|
|
push(stack, element);
|
|
return 0;
|
|
}
|
|
|
|
template <typename T>
|
|
int listLen(Stack& stack) {
|
|
T a;
|
|
pop(stack, a);
|
|
|
|
const int64_t size = a->elements().size();
|
|
push(stack, size);
|
|
return 0;
|
|
}
|
|
|
|
template <typename T>
|
|
int listEq(Stack& stack) {
|
|
T a;
|
|
T b;
|
|
pop(stack, a, b);
|
|
push(stack, a->elements() == b->elements());
|
|
return 0;
|
|
}
|
|
|
|
template <typename T>
|
|
int listNe(Stack& stack) {
|
|
T a;
|
|
T b;
|
|
pop(stack, a, b);
|
|
push(stack, !(a->elements() == b->elements()));
|
|
return 0;
|
|
}
|
|
|
|
inline bool tensor_list_equal(Shared<TensorList> a, Shared<TensorList> b) {
|
|
if (a->elements().size() != b->elements().size()) {
|
|
return false;
|
|
}
|
|
|
|
for (size_t i = 0; i < a->elements().size(); ++i) {
|
|
const auto& a_element = a->elements()[i];
|
|
const auto& b_element = b->elements()[i];
|
|
// This preserves Python's semantics, which uses eq() to compare two
|
|
// elements, then passes the result to bool().
|
|
// see: https://docs.python.org/3.4/reference/datamodel.html#object.__ge__
|
|
const auto cmp_result = a_element.eq(b_element);
|
|
if (!cmp_result.is_nonzero()) {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
// Specialization for at::Tensor, since it doesn't define operator==
|
|
template <>
|
|
int listEq<Shared<TensorList>>(Stack& stack) {
|
|
Shared<TensorList> a;
|
|
Shared<TensorList> b;
|
|
pop(stack, a, b);
|
|
push(stack, tensor_list_equal(a, b));
|
|
return 0;
|
|
}
|
|
|
|
// Specialization for at::Tensor, since it doesn't define operator==
|
|
template <>
|
|
int listNe<Shared<TensorList>>(Stack& stack) {
|
|
Shared<TensorList> a;
|
|
Shared<TensorList> b;
|
|
pop(stack, a, b);
|
|
push(stack, !tensor_list_equal(a, b));
|
|
return 0;
|
|
}
|
|
|
|
Operation listList(const Node* node) {
|
|
return [=](Stack& stack) {
|
|
// Intentional no-op, needed to match Python semantics for list(iterable),
|
|
// but in JIT these will already be lists
|
|
return 0;
|
|
};
|
|
}
|
|
|
|
template <class TList, class TElement>
|
|
int listAdd(Stack& stack) {
|
|
TList a;
|
|
TList b;
|
|
pop(stack, a, b);
|
|
|
|
std::vector<TElement> ret;
|
|
const auto total_size = a->elements().size() + b->elements().size();
|
|
ret.reserve(total_size);
|
|
for (const auto& a_element : a->elements()) {
|
|
ret.push_back(a_element);
|
|
}
|
|
for (const auto& b_element : b->elements()) {
|
|
ret.push_back(b_element);
|
|
}
|
|
|
|
push(stack, ret);
|
|
return 0;
|
|
}
|
|
|
|
template <class TList, class TElement>
|
|
int listMulIntLeft(Stack& stack) {
|
|
TList list;
|
|
int64_t n;
|
|
pop(stack, list, n);
|
|
|
|
std::vector<TElement> ret;
|
|
const auto size = list->elements().size() * n;
|
|
ret.reserve(size);
|
|
|
|
for (auto i = 0; i < n; i++) {
|
|
for (const auto& e : list->elements()) {
|
|
ret.push_back(e);
|
|
}
|
|
}
|
|
|
|
push(stack, ret);
|
|
return 0;
|
|
}
|
|
|
|
template <class TList, class TElement>
|
|
int listMulIntRight(Stack& stack) {
|
|
TList list;
|
|
int64_t n;
|
|
pop(stack, n, list);
|
|
|
|
std::vector<TElement> ret;
|
|
const auto size = list->elements().size() * n;
|
|
ret.reserve(size);
|
|
|
|
for (auto i = 0; i < n; i++) {
|
|
for (const auto& e : list->elements()) {
|
|
ret.push_back(e);
|
|
}
|
|
}
|
|
|
|
push(stack, ret);
|
|
return 0;
|
|
}
|
|
|
|
template <typename TList, typename TElement>
|
|
int listSlice(Stack& stack) {
|
|
TList list;
|
|
int64_t start;
|
|
int64_t end;
|
|
int64_t step;
|
|
|
|
pop(stack, list, start, end, step);
|
|
const int64_t list_size = list->elements().size();
|
|
|
|
// clamp start and end to the bounds of the list
|
|
const auto normalized_start =
|
|
std::max((int64_t)0, normalizeIndex(start, list_size));
|
|
const auto normalized_end =
|
|
std::min(list_size, normalizeIndex(end, list_size));
|
|
|
|
std::vector<TElement> sliced_list;
|
|
if (normalized_end <= normalized_start) {
|
|
// early exit if the slice is trivially empty
|
|
push(stack, sliced_list);
|
|
return 0;
|
|
}
|
|
|
|
sliced_list.reserve(normalized_end - normalized_start);
|
|
|
|
for (auto i = normalized_start; i < normalized_end;) {
|
|
sliced_list.push_back(list->elements()[i]);
|
|
i += step;
|
|
}
|
|
|
|
push(stack, sliced_list);
|
|
return 0;
|
|
}
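// e.g. slicing [0, 1, 2, 3, 4] with start=1, end=4, step=2 yields [1, 3];
// out-of-range bounds are clamped rather than raising, as in Python.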
|
|
|
|
template <typename TList>
|
|
int listSort(Stack& stack) {
|
|
TList list;
|
|
|
|
pop(stack, list);
|
|
std::sort(list->elements().begin(), list->elements().end());
|
|
return 0;
|
|
}
|
|
|
|
// Specialization for at::Tensor
|
|
template <>
|
|
int listSort<Shared<TensorList>>(Stack& stack) {
|
|
Shared<TensorList> list;
|
|
pop(stack, list);
|
|
std::sort(
|
|
list->elements().begin(),
|
|
list->elements().end(),
|
|
[](const at::Tensor& a, const at::Tensor& b) {
|
|
return a.lt(b).is_nonzero();
|
|
});
|
|
return 0;
|
|
}
|
|
|
|
template <typename TList, typename TElement>
|
|
int listSetItem(Stack& stack) {
|
|
TList list;
|
|
int64_t idx;
|
|
TElement value;
|
|
|
|
pop(stack, list, idx, value);
|
|
getItem(list, idx) = value;
|
|
|
|
push(stack, list);
|
|
return 0;
|
|
}
|
|
|
|
template <>
|
|
int listSetItem<Shared<BoolList>, bool>(Stack& stack) {
|
|
Shared<BoolList> list;
|
|
int64_t idx;
|
|
bool value;
|
|
|
|
pop(stack, list, idx, value);
|
|
|
|
int64_t list_size = list->elements().size();
|
|
auto normalized_idx = normalizeIndex(idx, list_size);
|
|
if (normalized_idx < 0 || normalized_idx >= list_size) {
|
|
throw std::out_of_range("list index out of range");
|
|
}
|
|
list->elements()[normalized_idx] = value;
|
|
|
|
push(stack, list);
|
|
return 0;
|
|
}
|
|
|
|
int dictSetItem(Stack& stack) {
|
|
auto value = pop(stack);
|
|
auto idx = pop(stack);
|
|
auto dict = pop(stack).toGenericDict();
|
|
dict->elements().insert_or_assign(std::move(idx), std::move(value));
|
|
push(stack, std::move(dict));
|
|
return 0;
|
|
}
|
|
|
|
int dictLen(Stack& stack) {
|
|
auto dict = pop(stack).toGenericDict();
|
|
push(stack, int64_t(dict->elements().size()));
|
|
return 0;
|
|
}
|
|
|
|
int dictKeys(Stack& stack) {
|
|
auto dict = pop(stack).toGenericDict();
|
|
std::vector<IValue> keys;
|
|
keys.reserve(dict->elements().size());
|
|
for (auto item : dict->elements()) {
|
|
keys.push_back(item.key());
|
|
}
|
|
push(stack, IValue(keys));
|
|
return 0;
|
|
}
|
|
|
|
template <typename Elem>
|
|
std::vector<Elem> makeListForDictValues(
|
|
const c10::ivalue::GenericDict::IterationOrder& order) {
|
|
std::vector<Elem> values;
|
|
values.reserve(order.size());
|
|
for (auto item : order) {
|
|
values.push_back(item.second.to<Elem>());
|
|
}
|
|
return values;
|
|
}
|
|
|
|
Operation dictValues(const Node* n) {
|
|
auto outputType = n->output()->type()->expect<ListType>();
|
|
return [=](Stack& stack) -> int {
|
|
const auto& order = pop(stack).toGenericDict()->iterationOrder();
|
|
if (outputType->getElementType()->isSubtypeOf(TensorType::get())) {
|
|
push(stack, makeListForDictValues<at::Tensor>(order));
|
|
} else if (outputType->getElementType() == IntType::get()) {
|
|
push(stack, makeListForDictValues<int64_t>(order));
|
|
} else if (outputType->getElementType() == FloatType::get()) {
|
|
push(stack, makeListForDictValues<double>(order));
|
|
} else if (outputType->getElementType() == BoolType::get()) {
|
|
push(stack, makeListForDictValues<bool>(order));
|
|
} else {
|
|
push(stack, makeListForDictValues<IValue>(order));
|
|
}
|
|
return 0;
|
|
};
|
|
}
|
|
|
|
int dictIndex(Stack& stack) {
|
|
auto index = pop(stack);
|
|
auto dict = pop(stack).toGenericDict();
|
|
const auto& elems = dict->elements();
|
|
auto value = elems.find(index);
|
|
if (value == elems.end()) {
|
|
AT_ERROR("KeyError: '", index, "'");
|
|
}
|
|
push(stack, value->value());
|
|
return 0;
|
|
}
|
|
|
|
int dictGet(Stack& stack) {
|
|
auto index = pop(stack);
|
|
auto dict = pop(stack).toGenericDict();
|
|
const auto& elems = dict->elements();
|
|
auto value = elems.find(index);
|
|
if (value == elems.end()) {
|
|
push(stack, IValue());
|
|
} else {
|
|
push(stack, value->value());
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
int dictGetDefault(Stack& stack) {
|
|
auto default_value = pop(stack);
|
|
auto index = pop(stack);
|
|
auto dict = pop(stack).toGenericDict();
|
|
const auto& elems = dict->elements();
|
|
auto value = elems.find(index);
|
|
if (value == elems.end()) {
|
|
push(stack, default_value);
|
|
} else {
|
|
push(stack, value->value());
|
|
}
|
|
return 0;
|
|
}
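// Mirrors Python's dict.get(key, default): the default is returned when the
// key is missing instead of raising a KeyError.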
|
|
|
|
template <typename T>
|
|
int hashValue(Stack& stack) {
|
|
auto value = pop(stack);
|
|
auto hash = std::hash<T>()(value.to<T>());
|
|
push(stack, int64_t(hash));
|
|
return 0;
|
|
}
|
|
|
|
RegisterOperators reg2({
|
|
|
|
#define DEFINE_STRING_OP(op_name, string_op, result) \
|
|
Operator(#op_name "(str a, str b) ->" #result, [](Stack& stack) { \
|
|
auto b = pop(stack).toStringRef(); \
|
|
auto a = pop(stack).toStringRef(); \
|
|
push(stack, string_op); \
|
|
return 0; \
|
|
})
|
|
|
|
DEFINE_STRING_OP(aten::eq, a == b, bool),
|
|
DEFINE_STRING_OP(aten::ne, a != b, bool),
|
|
DEFINE_STRING_OP(aten::add, a + b, str),
|
|
#undef DEFINE_STRING_OP
|
|
Operator(
|
|
"aten::len(str s) -> int",
|
|
[](Stack& stack) {
|
|
auto string = pop(stack).toStringRef();
|
|
push(stack, static_cast<int64_t>(string.size()));
|
|
return 0;
|
|
}),
|
|
// tensor length op (size of 1st dimension)
|
|
Operator(
|
|
"aten::len(Tensor t) -> int",
|
|
[](Stack& stack) {
|
|
at::Tensor t = pop(stack).toTensor();
|
|
if (t.dim() == 0) {
|
|
AT_ERROR("len() of a 0-d tensor");
|
|
}
|
|
push(stack, t.sizes()[0]);
|
|
return 0;
|
|
}),
|
|
Operator(
|
|
"aten::list(str t) -> str[]",
|
|
[](Stack& stack) {
|
|
auto str = pop(stack).toStringRef();
|
|
std::vector<IValue> chars;
|
|
chars.reserve(str.size());
|
|
for (auto c : str) {
|
|
chars.push_back(std::string(1, c));
|
|
}
|
|
push(stack, chars);
|
|
return 0;
|
|
}),
|
|
// Mutable ops for lists containing mutable types.
|
|
#define CREATE_MUTABLE_LIST_OPS(decl_type, c_type) \
|
|
Operator( \
|
|
"aten::select(" decl_type "[](a) list, int idx) -> " decl_type "(*)", \
|
|
listSelect<Shared<c_type>>), \
|
|
Operator( \
|
|
"aten::append( " decl_type "[](a!) self, " decl_type \
|
|
"(c -> *) el) -> " decl_type "[](a!)", \
|
|
listAppend<Shared<c_type>, c_type::ElemType>), \
|
|
Operator( \
|
|
"aten::reverse( " decl_type "[](a!) self) -> ()", \
|
|
listReverse<Shared<c_type>>), \
|
|
Operator( \
|
|
"aten::extend(" decl_type "[](a!) self, " decl_type \
|
|
" [] other) -> ()", \
|
|
listExtend<Shared<c_type>>), \
|
|
Operator( \
|
|
"aten::copy(" decl_type \
|
|
"[](a) self)" \
|
|
" -> " decl_type "[]", \
|
|
listCopy<Shared<c_type>>), \
|
|
Operator( \
|
|
"aten::_set_item(" decl_type "[](a!) l, int idx, " decl_type \
|
|
"(b -> *) el) -> " decl_type "[](a!)", \
|
|
listSetItem<Shared<c_type>, c_type::ElemType>), \
|
|
Operator( \
|
|
"aten::clear( " decl_type "[](a!) self) -> ()", \
|
|
listClear<Shared<c_type>>), \
|
|
Operator( \
|
|
"aten::insert( " decl_type \
|
|
"[](a!) self, int idx, \
|
|
" decl_type "(b -> *) el) -> ()", \
|
|
listInsert<Shared<c_type>, c_type::ElemType>), \
|
|
Operator( \
|
|
"aten::pop(" decl_type \
|
|
"[](a!) self, int idx=-1) \
|
|
-> " decl_type "(*)", \
|
|
listPop<Shared<c_type>>)
|
|
|
|
CREATE_MUTABLE_LIST_OPS("Tensor", TensorList),
|
|
|
|
Operator(
|
|
"aten::remove(Tensor[](a!) self, Tensor el) -> ()",
|
|
listRemove<Shared<TensorList>, at::Tensor>),
|
|
Operator(
|
|
"aten::index(Tensor[] self, Tensor el) -> int",
|
|
listIndex<Shared<TensorList>, at::Tensor>),
|
|
Operator(
|
|
"aten::count(Tensor[] self, Tensor el) -> int",
|
|
listCount<Shared<TensorList>, at::Tensor>),
|
|
|
|
// Mutable ops for lists containing immutable types.
|
|
#define CREATE_IMMUTABLE_LIST_OPS(decl_type, c_type) \
|
|
Operator( \
|
|
"aten::select(" decl_type "[] a, int b) -> " decl_type, \
|
|
listSelect<Shared<c_type>>), \
|
|
Operator( \
|
|
"aten::append(" decl_type "[](a!) self, " decl_type \
|
|
" el) -> " decl_type "[](a!)", \
|
|
listAppend<Shared<c_type>, c_type::ElemType>), \
|
|
Operator( \
|
|
"aten::reverse(" decl_type "[](a!) self) -> ()", \
|
|
listReverse<Shared<c_type>>), \
|
|
Operator( \
|
|
"aten::extend(" decl_type "[](a!) self, " decl_type \
|
|
" [] other) -> ()", \
|
|
listExtend<Shared<c_type>>), \
|
|
Operator( \
|
|
"aten::copy(" decl_type \
|
|
"[](a) self)" \
|
|
" -> " decl_type "[]", \
|
|
listCopy<Shared<c_type>>), \
|
|
Operator( \
|
|
"aten::_set_item(" decl_type "[](a!) l, int idx, " decl_type \
|
|
" el) -> " decl_type "[](a!)", \
|
|
listSetItem<Shared<c_type>, c_type::ElemType>), \
|
|
Operator( \
|
|
"aten::clear( " decl_type "[](a!) self) -> ()", \
|
|
listClear<Shared<c_type>>), \
|
|
Operator( \
|
|
"aten::insert( " decl_type \
|
|
"[](a!) self, int idx, \
|
|
" decl_type " el) -> ()", \
|
|
listInsert<Shared<c_type>, c_type::ElemType>), \
|
|
Operator( \
|
|
"aten::remove(" decl_type \
|
|
"[](a!) self, \
|
|
" decl_type " el) -> ()", \
|
|
listRemove<Shared<c_type>, c_type::ElemType>), \
|
|
Operator( \
|
|
"aten::index(" decl_type \
|
|
"[] self, \
|
|
" decl_type " el) -> int", \
|
|
listIndex<Shared<c_type>, c_type::ElemType>), \
|
|
Operator( \
|
|
"aten::count(" decl_type \
|
|
"[] self, \
|
|
" decl_type " el) -> int", \
|
|
listCount<Shared<c_type>, c_type::ElemType>), \
|
|
Operator( \
|
|
"aten::pop(" decl_type \
|
|
"[](a!) self, int idx=-1) \
|
|
-> " decl_type, \
|
|
listPop<Shared<c_type>>)
|
|
|
|
CREATE_IMMUTABLE_LIST_OPS("int", IntList),
|
|
CREATE_IMMUTABLE_LIST_OPS("float", DoubleList),
|
|
CREATE_IMMUTABLE_LIST_OPS("bool", BoolList),
|
|
|
|
// NOTE: this must be after the other list specializations so that operator
|
|
// resolution doesn't pick this up first
|
|
CREATE_MUTABLE_LIST_OPS("t", GenericList),
|
|
#undef CREATE_IMMUTABLE_LIST_OPS
|
|
#undef CREATE_MUTABLE_LIST_OPS
|
|
|
|
#define CREATE_LIST_OPS(decl_type, c_type) \
|
|
Operator("aten::len(" decl_type "[] a) -> int", listLen<Shared<c_type>>), \
|
|
Operator( \
|
|
"aten::add(" decl_type "[] a, " decl_type "[] b) -> " decl_type \
|
|
"[]", \
|
|
listAdd<Shared<c_type>, c_type::ElemType>), \
|
|
Operator( \
|
|
"aten::slice(" decl_type \
|
|
"[] l, int start, int end=9223372036854775807, int step=1) -> " decl_type \
|
|
"[]", \
|
|
listSlice<Shared<c_type>, c_type::ElemType>), \
|
|
Operator("aten::list(" decl_type "[] l) -> " decl_type "[]", listList), \
|
|
Operator( \
|
|
"aten::mul(" decl_type "[] l, int n) -> " decl_type "[]", \
|
|
listMulIntLeft<Shared<c_type>, c_type::ElemType>), \
|
|
Operator( \
|
|
"aten::mul(int n, " decl_type "[] l) -> " decl_type "[]", \
|
|
listMulIntRight<Shared<c_type>, c_type::ElemType>)
|
|
|
|
CREATE_LIST_OPS("int", IntList),
|
|
CREATE_LIST_OPS("float", DoubleList),
|
|
CREATE_LIST_OPS("bool", BoolList),
|
|
CREATE_LIST_OPS("Tensor", TensorList),
|
|
CREATE_LIST_OPS("t", GenericList),
|
|
#undef CREATE_LIST_OPS
|
|
Operator("aten::sort(int[](a!) self) -> ()", listSort<Shared<IntList>>),
|
|
Operator(
|
|
"aten::sort(float[](a!) self) -> ()",
|
|
listSort<Shared<DoubleList>>),
|
|
Operator(
|
|
"aten::sort(Tensor[](a!) self) -> ()",
|
|
listSort<Shared<TensorList>>),
|
|
Operator("aten::sort(bool[](a!) self) -> ()", listSort<Shared<BoolList>>),
|
|
|
|
Operator("aten::eq(int[] a, int[] b) -> bool", listEq<Shared<IntList>>),
|
|
Operator(
|
|
"aten::eq(float[] a, float[] b) -> bool",
|
|
listEq<Shared<DoubleList>>),
|
|
Operator(
|
|
"aten::eq(Tensor[] a, Tensor[] b) -> bool",
|
|
listEq<Shared<TensorList>>),
|
|
Operator("aten::eq(bool[] a, bool[] b) -> bool", listEq<Shared<BoolList>>),
|
|
Operator("aten::ne(int[] a, int[] b) -> bool", listNe<Shared<IntList>>),
|
|
Operator(
|
|
"aten::ne(float[] a, float[] b) -> bool",
|
|
listNe<Shared<DoubleList>>),
|
|
Operator(
|
|
"aten::ne(Tensor[] a, Tensor[] b) -> bool",
|
|
listNe<Shared<TensorList>>),
|
|
Operator("aten::ne(bool[] a, bool[] b) -> bool", listNe<Shared<BoolList>>),
|
|
Operator(
|
|
"aten::slice(str string, int start, int end=9223372036854775807, int step=1) -> str",
|
|
stringSlice),
|
|
|
|
// Python string is* methods return false for the empty string
|
|
#define DEFINE_STRING_IS_OP(op_name, char_op) \
|
|
Operator(#op_name "(str self) -> bool", [](Stack& stack) { \
|
|
auto string = pop(stack).toStringRef(); \
|
|
push( \
|
|
stack, \
|
|
string.size() != 0 && \
|
|
std::all_of(string.begin(), string.end(), [](char c) { \
|
|
return char_op(c); \
|
|
})); \
|
|
return 0; \
|
|
})
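// e.g. aten::isdigit("123") -> true, aten::isdigit("12a") -> false, and
// aten::isdigit("") -> false, matching Python's str.isdigit().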
|
|
|
|
// isupper and islower require there to be at least one alphabetic character,
// and ignore all non-alphabetic characters
|
|
Operator(
|
|
"aten::isupper(str self) -> bool",
|
|
[](Stack& stack) {
|
|
auto string = pop(stack).toStringRef();
|
|
bool found_alpha = false;
|
|
bool is_upper = true;
|
|
for (size_t i = 0; i < string.size() && is_upper; ++i) {
|
|
char c = string[i];
|
|
found_alpha |= std::isalpha(c);
|
|
is_upper &= (!std::isalpha(c) || std::isupper(c));
|
|
}
|
|
push(stack, found_alpha && is_upper);
|
|
return 0;
|
|
}),
|
|
Operator(
|
|
"aten::islower(str self) -> bool",
|
|
[](Stack& stack) {
|
|
auto string = pop(stack).toStringRef();
|
|
bool found_alpha = false;
|
|
bool is_lower = true;
|
|
for (size_t i = 0; i < string.size() && is_lower; ++i) {
|
|
char c = string[i];
|
|
found_alpha |= std::isalpha(c);
|
|
is_lower &= (!std::isalpha(c) || std::islower(c));
|
|
}
|
|
push(stack, found_alpha && is_lower);
|
|
return 0;
|
|
}),
|
|
|
|
DEFINE_STRING_IS_OP(aten::isdigit, std::isdigit),
|
|
DEFINE_STRING_IS_OP(aten::isspace, std::isspace),
|
|
DEFINE_STRING_IS_OP(aten::isalnum, std::isalnum),
|
|
DEFINE_STRING_IS_OP(aten::isalpha, std::isalpha),
|
|
|
|
#define DEFINE_STRING_CHAR_MAP_OP(op_name, char_op) \
|
|
Operator(#op_name "(str self) -> str", [](Stack& stack) { \
|
|
auto string = pop(stack).toStringRef(); \
|
|
std::stringstream ss; \
|
|
for (char c : string) { \
|
|
ss << static_cast<char>(char_op(c)); \
|
|
} \
|
|
push(stack, ss.str()); \
|
|
return 0; \
|
|
})
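// e.g. aten::upper("abc1") -> "ABC1"; non-alphabetic characters are passed
// through unchanged, as in Python's str.upper() / str.lower().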
|
|
|
|
DEFINE_STRING_CHAR_MAP_OP(aten::upper, std::toupper),
|
|
DEFINE_STRING_CHAR_MAP_OP(aten::lower, std::tolower),
|
|
|
|
Operator(
|
|
"prim::StringIndex(str string, int index) -> str",
|
|
[](Stack& stack) {
|
|
auto index = pop(stack).toInt();
|
|
auto string = pop(stack).toStringRef();
|
|
char c = string.at(index);
|
|
push(stack, std::string(&c, 1));
|
|
return 0;
|
|
}),
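    // Example: prim::StringIndex("abc", 1) above yields "b"; out-of-range
    // indices throw via std::string::at.
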
    Operator(
        "prim::str(t elem) -> str",
        [](Stack& stack) {
          std::stringstream ss;
          ss << pop(stack);
          push(stack, ss.str());
          return 0;
        }),
    Operator(
        "aten::ord(str string) -> int",
        [](Stack& stack) {
          auto string = pop(stack).toStringRef();
          TORCH_CHECK(
              string.size() == 1,
              "String for ord() must be 1 character, found ",
              string.size());
          uint8_t ord = string.at(0);
          push(stack, int64_t(ord));
          return 0;
        }),
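    // Example: aten::ord("a") above returns 97; any input whose length is not
    // exactly 1 fails the TORCH_CHECK.
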
#define CREATE_COPY_OP(other_type, c_type) \
  Operator( \
      "aten::copy_(Tensor(a!) self, " #other_type " other) -> Tensor(a!)", \
      [](Stack& stack) { \
        at::Tensor t; \
        c_type other; \
        pop(stack, t, other); \
        std::move(t) = other; /* NOLINT(bugprone-use-after-move) */ \
        push(stack, std::move(t)); /* NOLINT(bugprone-use-after-move) */ \
        return 0; \
      })

    CREATE_COPY_OP(Tensor, at::Tensor),
    CREATE_COPY_OP(int, int64_t),
    CREATE_COPY_OP(float, double),
#undef CREATE_COPY_OP

    DEFINE_BINARY_OP(aten::add, a + b),
    DEFINE_BINARY_OP(aten::sub, a - b),
    DEFINE_BINARY_OP(aten::mul, a * b),
    DEFINE_BINARY_OP(aten::pow, pow(a, b)),
    // min and max are in prim:: because there is a difference between
    // the python builtin 'min' and 'torch.min'
    DEFINE_BINARY_OP(prim::min, a < b ? a : b),
    DEFINE_BINARY_OP(prim::max, a > b ? a : b),
    // Pass in two ops for handling int and float separately, as % in C++
    // only works for ints. The modulus calculation differs between C++ and
    // Python for negative operands; we preserve the Python behavior since
    // it's more common and matches Python syntax, hence the conversion.
    DEFINE_GENERIC_OP(
        aten::remainder,
        (b + (a % b)) % b,
        fmod((b + fmod(a, b)), b),
        int,
        float),
    DEFINE_INT_FLOAT_OP(aten::remainder, fmod((b + fmod(a, b)), b), float),
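    // Worked example of why the extra arithmetic above is needed: C++'s
    // truncating % gives -7 % 3 == -1, whereas Python's -7 % 3 == 2;
    // (3 + (-7 % 3)) % 3 == (3 + -1) % 3 == 2 reproduces the Python result.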

    DEFINE_GENERIC_OP(
        aten::floordiv,
        floordiv(a, b),
        std::floor(a / b),
        int,
        float),
    DEFINE_INT_FLOAT_OP(aten::floordiv, std::floor(a / b), float),
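    // Worked example: Python floor division rounds toward negative infinity,
    // so -7 // 2 == -4, while C++'s truncating integer division gives
    // -7 / 2 == -3; floordiv() above implements the Python rounding.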

    // only used in loop unrolling, not exposed to end users
    DEFINE_INT_OP(aten::__round_to_zero_floordiv, a / b),

    DEFINE_INT_OP(aten::__and__, a & b),
    DEFINE_INT_OP(aten::__or__, a | b),
    DEFINE_INT_OP(aten::__xor__, a ^ b),

    Operator(
        "prim::abs(int x) -> int",
        [](Stack& stack) {
          int64_t x;
          pop(stack, x);
          push(stack, std::abs(x));
          return 0;
        }),
    Operator(
        "prim::abs(float x) -> float",
        [](Stack& stack) {
          double x;
          pop(stack, x);
          push(stack, std::abs(x));
          return 0;
        }),
    Operator(
        "prim::abs(Tensor x) -> Tensor",
        [](Stack& stack) {
          at::Tensor x;
          pop(stack, x);
          push(stack, x.abs());
          return 0;
        }),

    // NB: This is the python truediv operation
    Operator(
        "aten::div(int a, int b) -> float",
        [](Stack& stack) {
          int64_t a, b;
          pop(stack, a, b);
          push(stack, static_cast<double>(a) / static_cast<double>(b));
          return 0;
        }),
    Operator(
        "aten::div(float a, float b) -> float",
        [](Stack& stack) {
          double a, b;
          pop(stack, a, b);
          push(stack, a / b);
          return 0;
        }),
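    // Example: with true division, 7 / 2 evaluates to 3.5 rather than the
    // truncated integer 3.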

    Operator(
        "aten::floor(float a) -> float",
        [](Stack& stack) {
          double a;
          pop(stack, a);
          push(stack, std::floor(a));
          return 0;
        }),

    Operator(
        "aten::ceil(float a) -> float",
        [](Stack& stack) {
          double a;
          pop(stack, a);
          push(stack, std::ceil(a));
          return 0;
        }),

    Operator(
        "aten::log(float a) -> float",
        [](Stack& stack) {
          double a;
          pop(stack, a);
          push(stack, std::log(a));
          return 0;
        }),
    Operator(
        "aten::log(int a) -> float",
        [](Stack& stack) {
          int64_t a;
          pop(stack, a);
          push(stack, std::log(a));
          return 0;
        }),

    Operator(
        "aten::log1p(float a) -> float",
        [](Stack& stack) {
          double a;
          pop(stack, a);
          push(stack, std::log1p(a));
          return 0;
        }),
    Operator(
        "aten::log1p(int a) -> float",
        [](Stack& stack) {
          int64_t a;
          pop(stack, a);
          push(stack, std::log1p(a));
          return 0;
        }),

    Operator(
        "aten::log10(float a) -> float",
        [](Stack& stack) {
          double a;
          pop(stack, a);
          push(stack, std::log10(a));
          return 0;
        }),
    Operator(
        "aten::log10(int a) -> float",
        [](Stack& stack) {
          int64_t a;
          pop(stack, a);
          push(stack, std::log10(a));
          return 0;
        }),

    Operator(
        "aten::exp(float a) -> float",
        [](Stack& stack) {
          double a;
          pop(stack, a);
          push(stack, std::exp(a));
          return 0;
        }),
    Operator(
        "aten::exp(int a) -> float",
        [](Stack& stack) {
          int64_t a;
          pop(stack, a);
          push(stack, std::exp(a));
          return 0;
        }),

    Operator(
        "aten::sqrt(float a) -> float",
        [](Stack& stack) {
          double a;
          pop(stack, a);
          push(stack, std::sqrt(a));
          return 0;
        }),
    Operator(
        "aten::sqrt(int a) -> float",
        [](Stack& stack) {
          int64_t a;
          pop(stack, a);
          push(stack, std::sqrt(a));
          return 0;
        }),

    DEFINE_INT_OP(aten::gcd, gcd(a, b)),

    DEFINE_GENERIC_OP(
        aten::copysign,
        std::copysign(a, b),
        std::copysign(a, b),
        float,
        float),
    DEFINE_INT_FLOAT_OP(aten::copysign, std::copysign(a, b), float),

#define DEFINE_MATH_OP(aten_op, op, int_result, float_result) \
  Operator( \
      #aten_op "(int a) -> " #int_result, \
      [](Stack& stack) { \
        int64_t a; \
        pop(stack, a); \
        push(stack, op); \
        return 0; \
      }), \
  Operator(#aten_op "(float a) -> " #float_result, [](Stack& stack) { \
    double a; \
    pop(stack, a); \
    push(stack, op); \
    return 0; \
  })

    DEFINE_MATH_OP(aten::gamma, std::tgamma(a), float, float),
    DEFINE_MATH_OP(aten::erf, std::erf(a), float, float),
    DEFINE_MATH_OP(aten::erfc, std::erfc(a), float, float),
    DEFINE_MATH_OP(aten::expm1, std::expm1(a), float, float),
    DEFINE_MATH_OP(aten::fabs, std::fabs(a), float, float),
    DEFINE_MATH_OP(aten::lgamma, std::lgamma(a), float, float),
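    // Example: DEFINE_MATH_OP(aten::lgamma, std::lgamma(a), float, float)
    // expands to two operators, "aten::lgamma(int a) -> float" and
    // "aten::lgamma(float a) -> float", both computed with std::lgamma.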

    DEFINE_COMPARISON_OP(aten::ne, a != b),
    DEFINE_COMPARISON_OP(aten::eq, a == b),
    DEFINE_COMPARISON_OP(aten::lt, a < b),
    DEFINE_COMPARISON_OP(aten::gt, a > b),
    DEFINE_COMPARISON_OP(aten::le, a <= b),
    DEFINE_COMPARISON_OP(aten::ge, a >= b),
    DEFINE_BOOL_OP(aten::__and__, a && b),
    DEFINE_BOOL_OP(aten::__or__, a || b),
    DEFINE_BOOL_OP(aten::__xor__, a != b),

    Operator(
        "aten::neg(int self) -> int",
        [](Stack& stack) {
          push(stack, -pop(stack).toInt());
          return 0;
        }),
    Operator(
        "aten::neg(float self) -> float",
        [](Stack& stack) {
          push(stack, -pop(stack).toDouble());
          return 0;
        }),
    Operator(
        "aten::__not__(bool self) -> bool",
        [](Stack& stack) {
          push(stack, !pop(stack).toBool());
          return 0;
        }),
    Operator(
        "aten::__is__(t1 self, t2 obj) -> bool",
        [](Stack& stack) {
          IValue self, obj;
          pop(stack, self, obj);
          push(stack, self.isSameIdentity(obj));
          return 0;
        }),
    Operator(
        "aten::__isnot__(t1 self, t2 obj) -> bool",
        [](Stack& stack) {
          IValue self, obj;
          pop(stack, self, obj);
          push(stack, !self.isSameIdentity(obj));
          return 0;
        }),
    Operator(
        "aten::_tensor_to_list(Tensor self) -> int[]",
        [](Stack& stack) {
          at::Tensor t;
          pop(stack, t);
          std::vector<int64_t> elems;
          elems.reserve(t.size(0));
          for (int i = 0; i < t.size(0); i++) {
            elems.push_back(*t[i].data<int32_t>());
          }
          push(stack, jit::IntList::create(elems));
          return 0;
        }),
    Operator(
        "aten::_list_to_tensor(int[] self) -> Tensor",
        [](Stack& stack) {
          std::vector<int64_t> l;
          pop(stack, l);
          auto t = torch::empty(
              {static_cast<int64_t>(l.size())}, at::dtype(at::kInt));
          for (size_t i = 0; i < l.size(); i++) {
            t[i] = l[i];
          }
          push(stack, t);
          return 0;
        }),
#define CREATE_DICT_OPS(key_type) \
  Operator("aten::len(Dict(" key_type ", t) self) -> int", dictLen), \
  Operator( \
      "aten::keys(Dict(" key_type ", t) self) -> " key_type "[](*)", \
      dictKeys), \
  Operator( \
      "aten::values(Dict(" key_type ", t) self) -> t[](*)", dictValues), \
  Operator( \
      "prim::DictIndex(Dict(" key_type ", t) self, " key_type \
      " key) -> t(*)", \
      dictIndex), \
  Operator( \
      "aten::get(Dict(" key_type ", t) self, " key_type " key) -> t(*)?", \
      dictGet), \
  Operator( \
      "aten::get(Dict(" key_type ", t) self, " key_type \
      " key, t default_value) -> t(*)", \
      dictGetDefault), \
  Operator( \
      "aten::_set_item(Dict(" key_type ", t)(a!) l, " key_type \
      " idx, t(b -> *) v) -> ()", \
      dictSetItem)

    CREATE_DICT_OPS("str"),
    CREATE_DICT_OPS("int"),
    CREATE_DICT_OPS("float"),
#undef CREATE_DICT_OPS
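    // Example: the CREATE_DICT_OPS("str") instantiation above registers
    // aten::len, aten::keys, aten::values, prim::DictIndex, both aten::get
    // overloads, and aten::_set_item for Dict(str, t) values.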

    Operator("aten::hash(str t) -> int", hashValue<std::string>),
    Operator("aten::hash(int t) -> int", hashValue<int>),
    Operator("aten::hash(float t) -> int", hashValue<double>),
});

bool simpleClassTypeArg(const Argument& arg, const ClassTypePtr& type) {
  return arg.type() == type && !arg.kwarg_only() && !arg.default_value();
}

void checkSortSchema(const Node* node, const c10::TypePtr& list_element_type) {
  std::stringstream error_str;
  if (auto class_type = list_element_type->cast<ClassType>()) {
    if (auto method = class_type->getMethod("__lt__")) {
      const auto& lt_schema = method->getSchema();
      const auto& schema_args = lt_schema.arguments();
      bool error =
          (schema_args.size() != 2 ||
           !simpleClassTypeArg(schema_args[0], class_type) ||
           !simpleClassTypeArg(schema_args[1], class_type) ||
           lt_schema.returns().size() != 1 ||
           lt_schema.returns()[0].type() != BoolType::get());
      if (!error) {
        return;
      }
    }
    error_str << "To sort a list of " << class_type->python_str()
              << " it must define a "
              << "__lt__ method with two inputs of type "
              << class_type->python_str() << " that "
              << "returns a bool";
  } else {
    error_str
        << "Input to list sort must be a list of Tensors, ints, floats, bools or "
        << "a user-defined class that defines the __lt__ compare method"
        << ", got list of " << list_element_type->python_str() << "\n";
  }

  auto error_msg = script::ErrorReport(node->sourceRange());
  error_msg << error_str.str();
  throw error_msg;
}

// NB: this must be registered after the other aten::sort operators
RegisterOperators regSort({
    Operator(
        "aten::sort(t[](a!) self, bool reverse=False) -> ()",
        [](const Node* node) {
          const auto list_type =
              node->inputs().at(0)->type()->expect<ListType>();
          checkSortSchema(node, list_type->getElementType());
          const auto elem = list_type->getElementType()->expect<ClassType>();
          auto func = elem->getMethod("__lt__");
          return [func](Stack& stack) {
            bool reverse = pop(stack).toBool();
            auto g_list = pop(stack).toGenericList();
            Stack sort_stack;
            std::sort(
                g_list->elements().begin(),
                g_list->elements().end(),
                [func, reverse, &sort_stack](
                    const IValue& a, const IValue& b) -> bool {
                  // FBCode errors without this check - "strict weak ordering"
                  // TODO: remove when possible, since it just slows down
                  // sorting and doesn't do anything useful
                  if (a.isSameIdentity(b)) {
                    return false;
                  }
                  sort_stack.push_back(a);
                  sort_stack.push_back(b);
                  func->run(sort_stack);
                  return pop(sort_stack).toBool() ^ reverse;
                });
            return 0;
          };
        }),
});
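
// For illustration, a hypothetical TorchScript class that would satisfy
// checkSortSchema and thus be sortable through the generic aten::sort above
// (a sketch of user code, not registered here):
//
//   @torch.jit.script
//   class Pair(object):
//       def __init__(self, x):
//           # type: (int) -> None
//           self.x = x
//
//       def __lt__(self, other):
//           # type: (Pair) -> bool
//           return self.x < other.x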

// reference: _output_size in torch/nn/functional.py
// size can be none, int or intlist
// scale_factors can be none, float, or floatlist
std::vector<int64_t> _output_size(
    const at::Tensor& input,
    size_t dim,
    const IValue& size,
    const IValue& scale_factors) {
  if (!size.isNone()) {
    if (size.isInt()) {
      std::vector<int64_t> repeated(dim, size.toInt());
      return repeated;
    } else {
      return size.toIntListRef();
    }
  }
  std::vector<double> scale_repeated;
  if (scale_factors.isDouble()) {
    scale_repeated = std::vector<double>(dim, scale_factors.toDouble());
  } else {
    scale_repeated = scale_factors.toDoubleListRef();
  }
  std::vector<int64_t> ret;
  for (size_t i = 0; i < dim; ++i) {
    ret.push_back(std::floor(input.size(i + 2) * scale_repeated[i]));
  }
  return ret;
}
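
// Worked example: for a 4-D NCHW input of shape [2, 3, 8, 8] (dim == 2 spatial
// dimensions), size == None and scale_factors == 2.5, the result is
// [floor(8 * 2.5), floor(8 * 2.5)] == [20, 20].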

// reference: interpolate in torch/nn/functional.py
// size can be none, int or intlist
// scale_factors can be none, float, or floatlist
at::Tensor interpolate(
    const at::Tensor& input,
    const IValue& size,
    const IValue& scale_factors,
    const std::string& mode,
    c10::optional<bool> align_corners) {
  if ((mode == "nearest" || mode == "area")) {
    if (align_corners != c10::nullopt) {
      throw std::runtime_error(
          "align_corners option can only be set with the "
          "interpolating modes: linear | bilinear | bicubic | trilinear");
    }
  } else {
    if (align_corners == c10::nullopt) {
      AT_WARN(
          "Default upsampling behavior when mode=",
          mode,
          " is changed "
          "to align_corners=False since 0.4.0. Please specify align_corners=True "
          "if the old behavior is desired. See the documentation of nn.Upsample for details");
      align_corners = false;
    }
  }

  auto input_dim = input.dim();
  if (input_dim == 3 && mode == "nearest")
    return at::upsample_nearest1d(
        input, _output_size(input, 1, size, scale_factors));
  if (input_dim == 4 && mode == "nearest")
    return at::upsample_nearest2d(
        input, _output_size(input, 2, size, scale_factors));
  if (input_dim == 5 && mode == "nearest")
    return at::upsample_nearest3d(
        input, _output_size(input, 3, size, scale_factors));
  if (input_dim == 3 && mode == "area")
    return at::adaptive_avg_pool1d(
        input, _output_size(input, 1, size, scale_factors));
  if (input_dim == 4 && mode == "area")
    return at::adaptive_avg_pool2d(
        input, _output_size(input, 2, size, scale_factors));
  if (input_dim == 5 && mode == "area")
    return at::adaptive_avg_pool3d(
        input, _output_size(input, 3, size, scale_factors));
  if (input_dim == 3 && mode == "linear")
    return at::upsample_linear1d(
        input, _output_size(input, 1, size, scale_factors), *align_corners);
  if (input_dim == 3 && mode == "bilinear")
    throw std::runtime_error("Got 3D input, but bilinear mode needs 4D input");
  if (input_dim == 3 && mode == "bicubic")
    throw std::runtime_error("Got 3D input, but bicubic mode needs 4D input");
  if (input_dim == 3 && mode == "trilinear")
    throw std::runtime_error("Got 3D input, but trilinear mode needs 5D input");
  if (input_dim == 4 && mode == "linear")
    throw std::runtime_error("Got 4D input, but linear mode needs 3D input");
  if (input_dim == 4 && mode == "bilinear")
    return at::upsample_bilinear2d(
        input, _output_size(input, 2, size, scale_factors), *align_corners);
  if (input_dim == 4 && mode == "bicubic")
    return at::upsample_bicubic2d(
        input, _output_size(input, 2, size, scale_factors), *align_corners);
  if (input_dim == 4 && mode == "trilinear")
    throw std::runtime_error("Got 4D input, but trilinear mode needs 5D input");
  if (input_dim == 5 && mode == "linear")
    throw std::runtime_error("Got 5D input, but linear mode needs 3D input");
  if (input_dim == 5 && mode == "bilinear")
    throw std::runtime_error("Got 5D input, but bilinear mode needs 4D input");
  if (input_dim == 5 && mode == "bicubic")
    throw std::runtime_error("Got 5D input, but bicubic mode needs 4D input");
  if (input_dim == 5 && mode == "trilinear")
    return at::upsample_trilinear3d(
        input, _output_size(input, 3, size, scale_factors), *align_corners);

  AT_ERROR(
      "Input Error: Only 3D, 4D and 5D input Tensors supported",
      " (got ",
      input_dim,
      "D) for the modes: nearest | linear | bilinear | trilinear",
      " (got ",
      mode,
      ") ");
}
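
// Example of the dispatch above: a 4-D (NCHW) input with mode == "bilinear" is
// routed to at::upsample_bilinear2d, while a 3-D input with the same mode
// raises "Got 3D input, but bilinear mode needs 4D input".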

Operation interpolate_op(const Node* n) {
  return [](Stack& stack) {
    at::Tensor input;
    IValue size;
    IValue scale_factors;
    std::string mode;
    IValue align_corners;
    pop(stack, input, size, scale_factors, mode, align_corners);
    at::Tensor res = interpolate(
        input, size, scale_factors, mode, align_corners.toOptional<bool>());
    push(stack, res);
    return 0;
  };
}

// interpolate takes in float & float[] for scale factor
// upsample takes in int & int[], so convert the ints to floats before
// passing on to the interpolate op
IValue convert_scale_factor_to_double(const IValue& int_ivalue) {
  IValue scale_factor_double;
  if (int_ivalue.isInt()) {
    scale_factor_double = static_cast<double>(int_ivalue.toInt());
  } else if (int_ivalue.isIntList()) {
    auto int_list = int_ivalue.toIntListRef();
    std::vector<double> double_vec(int_list.begin(), int_list.end());
    scale_factor_double = double_vec;
  } else if (int_ivalue.isNone()) {
    return IValue();
  } else {
    std::stringstream ss;
    ss << "Expecting optional int or int list arg for scale factor, got "
       << int_ivalue;
    throw std::runtime_error(ss.str());
  }
  return scale_factor_double;
}
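
// Example: an int scale factor of 2 becomes 2.0, an int list [2, 3] becomes
// [2.0, 3.0], and a None scale factor is passed through unchanged.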

Operation upsample_nearest_op(const Node* n) {
  return [](Stack& stack) {
    at::Tensor input;
    IValue size;
    IValue scale_factor_int;
    pop(stack, input, size, scale_factor_int);
    IValue scale_factor_double =
        convert_scale_factor_to_double(scale_factor_int);
    at::Tensor res =
        interpolate(input, size, scale_factor_double, "nearest", c10::nullopt);
    push(stack, res);
    return 0;
  };
}

Operation upsample_op(const Node* n) {
  return [](Stack& stack) {
    at::Tensor input;
    IValue size;
    IValue scale_factor_int;
    std::string mode;
    IValue align_corners;
    pop(stack, input, size, scale_factor_int, mode, align_corners);
    IValue scale_factor_double =
        convert_scale_factor_to_double(scale_factor_int);
    at::Tensor res = interpolate(
        input,
        size,
        scale_factor_double,
        mode,
        align_corners.toOptional<bool>());
    push(stack, res);
    return 0;
  };
}

Operation upsample_bilinear_op(const Node* n) {
  return [](Stack& stack) {
    at::Tensor input;
    IValue size;
    IValue scale_factor_int;
    pop(stack, input, size, scale_factor_int);
    IValue scale_factor_double =
        convert_scale_factor_to_double(scale_factor_int);
    at::Tensor res =
        interpolate(input, size, scale_factor_double, "bilinear", true);
    push(stack, res);
    return 0;
  };
}

RegisterOperators reg3({
    Operator(
        "aten::__interpolate(Tensor input, int? size = None, float[]? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
        interpolate_op),
    Operator(
        "aten::__interpolate(Tensor input, int[]? size = None, float[]? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
        interpolate_op),
    Operator(
        "aten::__interpolate(Tensor input, int? size = None, float? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
        interpolate_op),
    Operator(
        "aten::__interpolate(Tensor input, int[]? size = None, float? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
        interpolate_op),

    Operator(
        "aten::__upsample_nearest(Tensor input, int? size = None, int? scale_factor = None) -> Tensor",
        upsample_nearest_op),
    Operator(
        "aten::__upsample_nearest(Tensor input, int[]? size = None, int? scale_factor = None) -> Tensor",
        upsample_nearest_op),

    Operator(
        "aten::__upsample(Tensor input, int? size = None, int? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
        upsample_op),
    Operator(
        "aten::__upsample(Tensor input, int[]? size = None, int? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
        upsample_op),

    Operator(
        "aten::__upsample_bilinear(Tensor input, int? size = None, int? scale_factor = None) -> Tensor",
        upsample_bilinear_op),
    Operator(
        "aten::__upsample_bilinear(Tensor input, int[]? size = None, int? scale_factor = None) -> Tensor",
        upsample_bilinear_op),
    Operator(
        "aten::__upsample_bilinear(Tensor input, int? size = None, int[]? scale_factor = None) -> Tensor",
        upsample_bilinear_op),
    Operator(
        "aten::__upsample_bilinear(Tensor input, int[]? size = None, int[]? scale_factor = None) -> Tensor",
        upsample_bilinear_op),

});

at::Tensor leaky_relu(const at::Tensor& tensor, double scalar) {
  return at::leaky_relu(tensor, scalar);
}
at::Tensor cat(const std::vector<at::Tensor>& tensors) {
  return at::cat(tensors);
}

std::string get_first(const std::vector<std::vector<std::string>>& strings) {
  return strings[0][0];
}

static auto reg4 =
    torch::jit::RegisterOperators()
        .op("_test::leaky_relu(Tensor self, float v=0.01) -> Tensor",
            &leaky_relu)
        .op("_test::cat(Tensor[] inputs) -> Tensor", &cat)
        .op("_test::get_first", &get_first);
} // namespace
} // namespace jit
} // namespace torch