#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/csrc/jit/runtime/static/ops.h>

#include <ATen/CPUFunctions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/ScalarOps.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/IndexingUtils.h>
#include <ATen/native/NonSymbolicBC.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TensorAdvancedIndexing.h>
#include <c10/util/intrusive_ptr.h>
#include <c10/util/irange.h>
#include <c10/util/ssize.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/mobile/promoted_prim_ops.h>
#include <torch/csrc/jit/runtime/register_ops_utils.h>
#include <torch/csrc/jit/runtime/vararg_functions.h>

namespace {
constexpr auto createBorrowedIValue =
    c10::MaybeOwnedTraits<c10::IValue>::createBorrow;
} // namespace

namespace torch::jit {

namespace {

std::vector<IValue> boxInputs(const ProcessedNode& pnode) {
  std::vector<IValue> result;
  for (const auto i : c10::irange(pnode.num_inputs())) {
    result.push_back(pnode.Input(i));
  }
  return result;
}

} // namespace

C10_DEFINE_REGISTRY(SRNativeOperatorRegistry, SROperatorFunctor)

bool nativeOpIsRegistered(const c10::Symbol& op_name) {
  const std::string name(op_name.toQualString());
  return SRNativeOperatorRegistry()->Has(name);
}

SROperator getNativeOperation(Node* n) {
  auto op_name = n->kind().toQualString();
  if (SRNativeOperatorRegistry()->Has(op_name)) {
    return SRNativeOperatorRegistry()->Create(op_name)->Generate(n);
  }
  return nullptr;
}

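// Each REGISTER_NATIVE_OPERATOR_FUNCTOR invocation below adds a functor to
// SRNativeOperatorRegistry under the op's qualified name;
// getNativeOperation() then looks the functor up by that name and calls
// Generate(n) to produce the SROperator for a given Node.
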
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::TupleConstruct,
    prim_TupleConstruct,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::TupleConstruct)) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        // prepare inputs
        auto stack = boxInputs(*p_node);
        // run op
        auto* node = p_node->node();
        const auto& type = node->output()->type()->expect<TupleType>();
        if (type->name().has_value()) {
          namedTupleConstruct(stack, type, node->inputs().size());
        } else {
          tupleConstruct(stack, node->inputs().size());
        }
        // put output back
        p_node->Output(0) = std::move(stack[0]);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::TupleUnpack,
    prim_TupleUnpack,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::TupleUnpack)) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& elems = p_node->Input(0).toTupleRef().elements();
        const size_t num_outputs = p_node->outputs().size();
        TORCH_CHECK(
            num_outputs == elems.size(),
            "Number of outputs must match number of tuple elements.")
        for (size_t i = 0; i < num_outputs; ++i) {
          p_node->Output(i) = elems[i];
        }
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::DictConstruct,
    prim_DictConstruct,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::DictConstruct)) {
        return nullptr;
      }
      auto dict_type = n->output()->type()->expect<DictType>();
      const auto num_inputs = n->inputs().size();
      TORCH_DCHECK_EQ(num_inputs % 2, 0);
      return [dict_type = std::move(dict_type),
              num_inputs,
              dict_size = num_inputs / 2](ProcessedNode* p_node) {
        auto result = c10::impl::GenericDict(
            dict_type->containedType(0), dict_type->containedType(1));
        result.reserve(dict_size);
        for (size_t i = 0; i < num_inputs; i += 2) {
          const auto& key = p_node->Input(i);
          const auto& value = p_node->Input(i + 1);
          result.insert_or_assign(key, value);
        }
        p_node->Output(0) = result;
      };
    });

// See [Borrowed IValue Outputs]
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    static_runtime::dict_unpack,
    static_runtime_dict_unpack,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "static_runtime::dict_unpack(...) -> ...")) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        DCHECK(
            static_cast<size_t>(p_node->num_inputs() - 1) ==
            p_node->outputs().size());
        auto dict = p_node->Input(0).toGenericDict();
        const auto num_inputs = p_node->num_inputs();
        for (size_t i = 1; i < num_inputs; ++i) {
          const auto& key = p_node->Input(i);
          auto value = dict.find(key);
          TORCH_CHECK(value != dict.end(), "Key not in dict: ", key);
          p_node->Output(i - 1) = createBorrowedIValue(value->value());
        }
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::__getitem__, aten_getitem, [](Node* n) -> SROperator {
  if (!sr_schema_check(
          n,
          // TODO: "aten::__getitem__.str(str s, int index) -> str",
          "aten::__getitem__.t(t[](a) list, int idx) -> t(*)",
          "aten::__getitem__.Dict_str(Dict(str, t) self, str key) -> t(*)",
          "aten::__getitem__.Dict_int(Dict(int, t) self, int key) -> t(*)",
          "aten::__getitem__.Dict_bool(Dict(bool, t) self, bool key) -> t(*)",
          "aten::__getitem__.Dict_float(Dict(float, t) self, float key) -> t(*)",
          "aten::__getitem__.Dict_complex(Dict(complex, t) self, complex key) -> t(*)",
          "aten::__getitem__.Dict_Tensor(Dict(Tensor, t) self, Tensor key) -> t(*)")) {
    return nullptr;
  }

  if (n->inputs().size() != 2) {
    return nullptr;
  }

  if (n->input(0)->type()->castRaw<DictType>()) {
    return [](ProcessedNode* p_node) {
      auto dict = p_node->Input(0).toGenericDict();
      const auto& key = p_node->Input(1);
      auto value = dict.find(key);
      TORCH_CHECK(value != dict.end(), "Key not in dict: ", key);
      p_node->Output(0) = value->value();
    };
  } else if (n->input(0)->type()->castRaw<ListType>()) {
    return [](ProcessedNode* p_node) {
      const auto& list = p_node->Input(0).toList();
      auto idx = p_node->Input(1).toInt();
      p_node->Output(0) = getItem(list, idx);
    };
  }

  // TODO(T98581096): make __getitem__ work for other container types
  return nullptr;
});

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::ListConstruct,
    prim_ListConstruct,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::ListConstruct)) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        // prepare inputs
        auto stack = boxInputs(*p_node);
        // run op
        listConstruct(
            stack,
            p_node->node()->output()->type()->expectRef<ListType>(),
            p_node->num_inputs());
        // put output back
        p_node->Output(0) = std::move(stack[0]);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::ListUnpack,
    prim_ListUnpack,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::ListUnpack)) {
        return nullptr;
      }
      const auto num_outputs = n->outputs().size();
      return [num_outputs](ProcessedNode* p_node) {
        const auto list = p_node->Input(0).toListRef();
        TORCH_CHECK(
            list.size() == num_outputs,
            "Expected ",
            num_outputs,
            " elements in list but got ",
            list.size());
        for (const auto i : c10::irange(num_outputs)) {
          p_node->Output(i) = list[i];
        }
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::append,
    aten_append,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(
              n, "aten::append.t(t[](a!) self, t(c -> *) el) -> t[](a!)")) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        auto list = p_node->Input(0).toList();
        list.push_back(p_node->Input(1));
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::list,
    aten_list,
    [](Node* n) -> SROperator {
      if (n->matches(torch::schema("aten::list(str t) -> str[]"))) {
        return [](ProcessedNode* p_node) {
          const auto str = p_node->Input(0).toStringRef();
          c10::List<std::string> chars;
          chars.reserve(str.size());
          for (auto c : str) {
            chars.emplace_back(1, c);
          }
          p_node->Output(0) = std::move(chars);
        };
      }

      if (n->matches(torch::schema("aten::list.t(t[] l) -> t[]"))) {
        return [](ProcessedNode* p_node) {
          const auto input = p_node->Input(0).toList();
          p_node->Output(0) = input.copy();
        };
      }

      LogAndDumpSchema(n);
      return nullptr;
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::numel,
    aten_numel,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "aten::numel(Tensor self) -> int")) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& arg = p_node->Input(0).toTensor();
        p_node->Output(0) = arg.numel();
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::cpu,
    aten_cpu,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "aten::cpu(Tensor self) -> Tensor")) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& arg = p_node->Input(0).toTensor();
        p_node->Output(0) = arg.cpu();
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::__range_length,
    aten_range_length,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(
              n, "aten::__range_length(int lo, int hi, int step) -> int")) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        auto lo = p_node->Input(0).toInt();
        auto hi = p_node->Input(1).toInt();
        auto step = p_node->Input(2).toInt();
        // error handling when step == 0 during runtime
        if (step == 0) {
          throw std::runtime_error("range() arg 3 must not be zero");
        }
        if (step > 0 && lo < hi) {
          p_node->Output(0) = 1 + (hi - 1 - lo) / step;
        } else if (step < 0 && lo > hi) {
          p_node->Output(0) = 1 + (lo - 1 - hi) / (0 - step);
        } else {
          p_node->Output(0) = 0;
        }
      };
    });

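// Worked example for the length computation above: lo = 0, hi = 10, step = 3
// gives 1 + (hi - 1 - lo) / step = 1 + (10 - 1 - 0) / 3 = 4, which matches
// Python's len(range(0, 10, 3)).
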
REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::index_put, aten_index_put, [](Node* n) -> SROperator {
  if (n->matches(torch::schema(
          "aten::index_put(Tensor self, Tensor[] indices, Tensor values, bool accumulate=False) -> Tensor")) ||
      n->matches(torch::schema(
          "aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor"))) {
    return [](ProcessedNode* p_node) {
      const auto& self = p_node->Input(0).toTensor();
      const auto& indices =
          at::native::toListOfOptionalTensors(p_node->Input(1).toListRef());
      const auto& values = p_node->Input(2).toTensor();
      const auto accumulate = p_node->Input(3).toBool();
      p_node->Output(0) =
          at::native::index_put(self, indices, values, accumulate);
    };
  }

  LogAndDumpSchema(n);
  return nullptr;
});

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::item,
    aten_item,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "aten::item(Tensor self) -> Scalar")) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& self = p_node->Input(0).toTensor();
        p_node->Output(0) = at::native::item(self);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::GetAttr,
    prim_GetAttr,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::GetAttr)) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        auto& module = p_node->Input(0).toObjectRef();
        Node* node = p_node->node();
        const auto& type = node->input()->type()->expectRef<ClassType>();
        const auto& field = node->s(attr::name);
        const auto slot = type.getAttributeSlot(field);
        p_node->Output(0) = module.getSlot(slot);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::SetAttr,
    prim_SetAttr,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::SetAttr)) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        auto& module = p_node->Input(0).toObjectRef();
        Node* node = p_node->node();
        const auto& type = node->inputs()[0]->type()->expectRef<ClassType>();
        const auto& field = node->s(attr::name);
        const auto slot = type.getAttributeSlot(field);
        module.setSlot(slot, p_node->Input(1));
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::transpose,
    aten_transpose,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& in0_t = p_node->Input(0).toTensor();
        const auto in1_i = p_node->Input(1).toInt();
        const auto in2_i = p_node->Input(2).toInt();
        p_node->Output(0) = at::native::transpose(in0_t, in1_i, in2_i);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::flatten, aten_flatten, [](Node* n) -> SROperator {
  if (!n->matches(torch::schema(
          "aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)"))) {
    LogAndDumpSchema(n);
    return nullptr;
  }
  return [](ProcessedNode* p_node) {
    const auto& in0_t = p_node->Input(0).toTensor();
    const auto in1_i = p_node->Input(1).toInt();
    const auto in2_i = p_node->Input(2).toInt();
    p_node->Output(0) = at::native::flatten(in0_t, in1_i, in2_i);
  };
});

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::permute,
    aten_permute,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& in0_t = p_node->Input(0).toTensor();
        const auto in1_iv = p_node->Input(1).toDimVector();
        p_node->Output(0) = at::native::permute(in0_t, in1_iv);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::reshape,
    aten_reshape,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::reshape(Tensor(a) self, int[] shape) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& in0_t = p_node->Input(0).toTensor();
        const auto in1_iv = p_node->Input(1).toDimVector();
        p_node->Output(0) = at::native::reshape(in0_t, in1_iv);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::slice, aten_slice, [](Node* n) -> SROperator {
  if (!n->matches(torch::schema(
          "aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=0, int? end=9223372036854775807, int step=1) -> Tensor(a)"))) {
    LogAndDumpSchema(n);
    return nullptr;
  }
  return [](ProcessedNode* p_node) {
    const auto& in0_t = p_node->Input(0).toTensor();
    const auto in1_i = p_node->Input(1).toInt();
    const auto in2_i = p_node->Input(2).toOptional<int64_t>();
    const auto in3_i = p_node->Input(3).toOptional<int64_t>();
    const auto in4_i = p_node->Input(4).toInt();
    p_node->Output(0) = at::native::slice(in0_t, in1_i, in2_i, in3_i, in4_i);
  };
});

REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::narrow, aten_narrow, [](Node* n) -> SROperator {
  if (!n->matches(torch::schema(
          "aten::narrow(Tensor(a) self, int dim, int start, int length) -> Tensor(a)")) &&
      !n->matches(torch::schema(
          "aten::narrow.Tensor(Tensor(a) self, int dim, Tensor start, int length) -> Tensor(a)"))) {
    LogAndDumpSchema(n);
    return nullptr;
  }
  return [](ProcessedNode* p_node) {
    const auto& self = p_node->Input(0).toTensor(); // self
    const auto dim = p_node->Input(1).toInt(); // dim
    int64_t start = 0;
    if (p_node->Input(2).isScalar()) {
      start = p_node->Input(2).toInt();
    } else {
      auto& t = p_node->Input(2).toTensor();
      start = t.item<int64_t>();
    }
    const auto length = p_node->Input(3).toInt(); // length
    TORCH_CHECK(
        self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor.");
    auto cur_size = self.sizes()[dim];
    if (start != cur_size && start < 0) { // start being the end is valid, but
                                          // not a valid dim specification.
      start = at::maybe_wrap_dim(start, cur_size);
    }
    TORCH_CHECK(
        length >= 0 && start <= cur_size - length,
        "start (",
        start,
        ") + length (",
        length,
        ") exceeds dimension size (",
        cur_size,
        ").");
    p_node->Output(0) = at::native::slice(self, dim, start, start + length, 1);
  };
});

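// For example, narrow(self, /*dim=*/0, /*start=*/2, /*length=*/3) reduces to
// slice(self, 0, 2, 5, 1), i.e. rows 2, 3, and 4 along the first dimension.
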
REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::to, aten_to, [](Node* n) -> SROperator {
  if (n->matches(torch::schema(
          "aten::to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)"))) {
    return [](ProcessedNode* p_node) {
      const auto& in0_t = p_node->Input(0).toTensor();
      const auto& in1_t = p_node->Input(1).toTensor();
      const auto in2_i = p_node->Input(2).toBool();
      const auto in3_i = p_node->Input(3).toBool();
      const auto in4_o = p_node->Input(4).toOptional<at::MemoryFormat>();
      p_node->Output(0) = at::native::to(in0_t, in1_t, in2_i, in3_i, in4_o);
    };
  }
  if (n->matches(torch::schema(
          "aten::to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)"))) {
    return [](ProcessedNode* p_node) {
      const auto& in0_t = p_node->Input(0).toTensor();
      const auto in1_i = p_node->Input(1).toScalarType();
      const auto in2_i = p_node->Input(2).toBool();
      const auto in3_i = p_node->Input(3).toBool();
      const auto in4_o = p_node->Input(4).toOptional<at::MemoryFormat>();
      p_node->Output(0) = at::native::to(in0_t, in1_i, in2_i, in3_i, in4_o);
    };
  }
  if (n->matches(torch::schema(
          "aten::to.prim_dtype(Tensor(a) self, int? dtype, bool non_blocking=False, bool copy=False) -> Tensor(a|b)"))) {
    return [](ProcessedNode* p_node) {
      const auto& in0_t = p_node->Input(0).toTensor();
      const auto in1_i = p_node->Input(1).toOptional<at::ScalarType>();
      const auto in2_i = p_node->Input(2).toBool();
      const auto in3_i = p_node->Input(3).toBool();
      // To mimic the behavior of the JIT interpreter, if both dtype
      // and copy are not set, we return self. Otherwise, we assume
      // that dtype is set.
      if (!in1_i && !in3_i) {
        p_node->Output(0) = in0_t;
      } else {
        TORCH_CHECK(
            in1_i,
            "dtype cannot be None when copy is True for aten::to.prim_dtype");
        p_node->Output(0) = at::native::to(in0_t, *in1_i, in2_i, in3_i);
      }
    };
  }
  LogAndDumpSchema(n);
  return nullptr;
});

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::detach,
    aten_detach,
    [](Node* n) -> SROperator {
      if (!n->matches(
              torch::schema("aten::detach(Tensor(a) self) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& in0_t = p_node->Input(0).toTensor();
        p_node->Output(0) = at::native::alias(in0_t);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::expand_as,
    aten_expand_as,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& self = p_node->Input(0).toTensor();
        const auto& other = p_node->Input(1).toTensor();
        p_node->Output(0) = self.expand(other.sizes());
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::isinstance,
    prim_isinstance,
    [](Node* n) -> SROperator {
      if (!n->matches(
              torch::schema("prim::isinstance(Any to_check) -> bool"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        auto input_type = p_node->Input(0).type();

        auto* node = p_node->node();
        const std::vector<TypePtr>& candidates = node->tys(attr::types);
        for (const auto& candidate_type : candidates) {
          if (input_type->isSubtypeOf(*candidate_type)) {
            p_node->Output(0) = true;
            return;
          }
        }

        p_node->Output(0) = false;
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::TypeCheck,
    prim_TypeCheck,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::TypeCheck)) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        auto* node = p_node->node();
        const size_t num_inputs = node->inputs().size();
        TORCH_INTERNAL_ASSERT(
            num_inputs && num_inputs + 1 == node->outputs().size());

        const auto& expected_types = node->tys(attr::types);

        for (size_t i = 0; i < num_inputs; i++) {
          p_node->Output(i) = p_node->Input(i);
        }

        for (size_t i = 0; i < num_inputs; i++) {
          auto& input_tensor = p_node->Input(i).toTensor();
          auto* expected_type = expected_types[i]->castRaw<TensorType>();
          if (input_tensor.defined() &&
              !expected_type->matchTensor(input_tensor)) {
            p_node->Output(num_inputs) = false;
            return;
          }
        }

        p_node->Output(num_inputs) = true;
      };
    });

// See [Borrowed IValue Outputs]
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    static_runtime::VarTupleUnpack,
    static_runtime_VarTupleUnpack,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "static_runtime::VarTupleUnpack(...) -> ...")) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        size_t output_idx = 0;
        for (const auto idx : c10::irange(pnode->num_inputs())) {
          const auto& tuple = pnode->Input(idx);
          for (auto& elem : tuple.toTupleRef().elements()) {
            pnode->Output(output_idx) = createBorrowedIValue(elem);
            ++output_idx;
          }
        }
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::view,
    aten_view,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::view(Tensor(a) self, int[] size) -> (Tensor(a))"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& input = p_node->Input(0).toTensor();
        const auto size = p_node->Input(1).toIntList();
        p_node->Output(0) = at::native::view(input, size.vec());
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::size,
    aten_size,
    [](Node* n) -> SROperator {
      if (n->matches(
              torch::schema("aten::size(Tensor self, int dim) -> int"))) {
        return [](ProcessedNode* p_node) {
          const auto& input = p_node->Input(0).toTensor();
          auto dim = p_node->Input(1).toInt();
          const auto ndim = input.dim();

          if (dim < 0 || dim >= ndim) {
            dim = c10::maybe_wrap_dim(dim, ndim);
          }
          p_node->Output(0) = input.sizes()[dim];
        };
      }
      if (n->matches(torch::schema("aten::size(Tensor self) -> int[]"))) {
        return [](ProcessedNode* p_node) {
          const auto& input = p_node->Input(0).toTensor();
          p_node->Output(0) = input.sizes();
        };
      }
      LogAndDumpSchema(n);
      return nullptr;
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::squeeze,
    aten_squeeze,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }

      return [](ProcessedNode* p_node) {
        const auto& self = p_node->Input(0).toTensor();
        const auto dim = p_node->Input(1).toInt();
        p_node->Output(0) = at::native::squeeze(self, dim);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::split, aten_split, [](Node* n) -> SROperator {
  if (n->matches(torch::schema(
          "aten::split(Tensor(a -> *) self, int split_size, int dim=0) -> Tensor(a)[]"))) {
    return [](ProcessedNode* p_node) {
      const auto& self = p_node->Input(0).toTensor();
      const auto split_size = p_node->Input(1).toInt();
      const auto dim = p_node->Input(2).toInt();
      p_node->Output(0) = at::native::split(self, split_size, dim);
    };
  }

  if (n->matches(torch::schema(
          "aten::split(Tensor(a -> *) self, int[] split_sizes, int dim=0) -> (Tensor[])"))) {
    return [](ProcessedNode* p_node) {
      const auto& self = p_node->Input(0).toTensor();
      const auto& split_sizes = p_node->Input(1).toIntList();
      const auto dim = p_node->Input(2).toInt();
      p_node->Output(0) =
          at::native::split_with_sizes(self, split_sizes.vec(), dim);
    };
  }

  LogAndDumpSchema(n);
  return nullptr;
});

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::split_with_sizes,
    aten_split_with_sizes,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::split_with_sizes(Tensor(a -> *) self, int[] split_sizes, int dim=0) -> Tensor(a)[]")) &&
          !n->matches(torch::schema(
              "aten::split_with_sizes(Tensor(a -> *) self, int[] split_sizes, int dim=0) -> (Tensor[])"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& self = p_node->Input(0).toTensor();
        const auto& split_sizes = p_node->Input(1).toIntList();
        const auto dim = p_node->Input(2).toInt();
        p_node->Output(0) =
            at::native::split_with_sizes(self, split_sizes.vec(), dim);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    static_runtime::select_tensor,
    aten_select_tensor,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(
              n,
              "static_runtime::select_tensor(Tensor(a) a, Tensor(b) b, bool use_b) -> Tensor(a|b)")) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto did_copy = p_node->Input(2).toBool();
        DCHECK(p_node->Input(0).isTensor());
        DCHECK(!did_copy || p_node->Input(1).isTensor());
        const IValue& assignFrom =
            did_copy ? p_node->Input(1) : p_node->Input(0);
        // Create an IValue that borrows the input Tensor in order to
        // save a refcount increment here and decrement in
        // MemoryPlanner::deallocate. MemoryPlanner knows about this
        // and will safely clean it up by using the corresponding
        // destroyBorrow method.
        TORCH_DCHECK_NE(&assignFrom, &p_node->Output(0));
        // MemoryPlanner should have cleaned this up!
        DCHECK(p_node->Output(0).isNone());
        p_node->Output(0) =
            IValue(c10::MaybeOwnedTraits<at::TensorBase>::createBorrow(
                assignFrom.toTensor()));
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::mul,
    aten_mul,
    [](Node* n) -> SROperator {
      if (!n->matches(
              torch::schema("aten::mul.left_t(t[] l, int n) -> (t[])"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& list = pnode->Input(0).toList();
        const auto n = pnode->Input(1).toInt();

        auto list_type = list.elementType();
        auto ret = c10::impl::GenericList(list_type);
        ret.reserve(list.size() * n);
        for (const auto i : c10::irange(n)) {
          (void)i;
          for (const auto& ival : list) {
            ret.push_back(ival);
          }
        }
        pnode->Output(0) = ret;
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::sub,
    aten_sub,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema("aten::sub.int(int a, int b) -> (int)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto a = pnode->Input(0).toInt();
        const auto b = pnode->Input(1).toInt();
        pnode->Output(0) = a - b;
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::add,
    aten_add,
    [](Node* n) -> SROperator {
      if (n->matches(torch::schema("aten::add.t(t[] a, t[] b) -> (t[])"))) {
        return [](ProcessedNode* pnode) {
          const auto& a = pnode->Input(0).toList();
          const auto& b = pnode->Input(1).toList();
          auto ret = a.copy();
          ret.append(b);
          pnode->Output(0) = ret;
        };
      }

      if (n->matches(torch::schema("aten::add.int(int a, int b) -> (int)"))) {
        return [](ProcessedNode* pnode) {
          const auto a = pnode->Input(0).toInt();
          const auto b = pnode->Input(1).toInt();
          pnode->Output(0) = a + b;
        };
      }

      LogAndDumpSchema(n);
      return nullptr;
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::tensor_split, aten_tensor_split, [](Node* n) -> SROperator {
  if (n->matches(torch::schema(
          "aten::tensor_split.indices(Tensor(a -> *) self, int[] indices, int dim=0) -> Tensor(a)[]"))) {
    return [](ProcessedNode* pnode) {
      const auto& a = pnode->Input(0).toTensor();
      const auto& b = pnode->Input(1).toIntVector();
      const auto c = pnode->Input(2).toInt();
      pnode->Output(0) = at::native::tensor_split(a, b, c);
    };
  }

  if (n->matches(torch::schema(
          "aten::tensor_split.sections(Tensor(a -> *) self, int sections, int dim=0) -> Tensor(a)[]"))) {
    return [](ProcessedNode* pnode) {
      const auto& a = pnode->Input(0).toTensor();
      const auto b = pnode->Input(1).toSymInt();
      const auto c = pnode->Input(2).toInt();
      pnode->Output(0) = at::native::tensor_split_sections_symint(a, b, c);
    };
  }

  if (n->matches(torch::schema(
          "aten::tensor_split.tensor_indices_or_sections(Tensor(a -> *) self, Tensor tensor_indices_or_sections, int dim=0) -> Tensor(a)[]"))) {
    return [](ProcessedNode* pnode) {
      const auto& a = pnode->Input(0).toTensor();
      const auto& b = pnode->Input(1).toTensor();
      const auto c = pnode->Input(2).toInt();
      pnode->Output(0) = at::native::tensor_split(a, b, c);
    };
  }
  LogAndDumpSchema(n);
  return nullptr;
});

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::Int,
    aten_Int,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema("aten::Int(Tensor a) -> int"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& input = pnode->Input(0).toTensor();
        pnode->Output(0) = at::native::item(input).toInt();
      };
    });

// See [Create owned refs for special values]
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    static_runtime::create_owned_ref,
    static_runtime_create_owned_ref,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "static_runtime::create_owned_ref(...) -> ...")) {
        return nullptr;
      }
      return
          [](ProcessedNode* p_node) { p_node->Output(0) = p_node->Input(0); };
    });

namespace {
bool outputsEmpty(const Block* block) {
  return block->outputs().size() == 1 && block->outputs().at(0)->mustBeNone();
}

bool blockEmpty(const Block* block) {
  return block->nodes().begin() == block->nodes().end();
}

enum class BlockRunPlan : int8_t {
  kRunOnlyTrueBlock,
  kRunOnlyFalseBlock,
  kRunBothBlocks,
  kRunNeitherBlock,
};
} // namespace

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::If,
    prim_If,
    [](Node* node) -> SROperator {
      if (!sr_schema_check_kind(node, prim::If)) {
        return nullptr;
      }
      TORCH_DCHECK_EQ(node->blocks().size(), 2);
      const Block* true_block = node->blocks().at(0);
      const Block* false_block = node->blocks().at(1);

      const bool true_block_returns_empty = outputsEmpty(true_block);
      const bool false_block_returns_empty = outputsEmpty(false_block);

      BlockRunPlan block_run_plan = BlockRunPlan::kRunNeitherBlock;

      if (true_block_returns_empty && false_block_returns_empty) {
        const bool false_block_is_empty = blockEmpty(false_block);
        const bool true_block_is_empty = blockEmpty(true_block);

        if (false_block_is_empty && !true_block_is_empty) {
          block_run_plan = BlockRunPlan::kRunOnlyTrueBlock;
        } else if (!false_block_is_empty && true_block_is_empty) {
          block_run_plan = BlockRunPlan::kRunOnlyFalseBlock;
        } else if (false_block_is_empty && true_block_is_empty) {
          block_run_plan = BlockRunPlan::kRunNeitherBlock;
        } else {
          block_run_plan = BlockRunPlan::kRunBothBlocks;
        }
      } else {
        block_run_plan = BlockRunPlan::kRunBothBlocks;
      }

      switch (block_run_plan) {
        case BlockRunPlan::kRunBothBlocks:
          return [](ProcessedNode* p_node) {
            auto condition = p_node->Input(0).toBool();
            auto* metadata = p_node->metadata();
            DCHECK(metadata);
            auto& block_runners = metadata->block_runners();
            TORCH_DCHECK_EQ(block_runners.size(), 2);
            auto& runner = block_runners[!condition];

            auto output = runner({});
            // If we are returning a tuple, we are either returning
            // multiple unpacked values or all of the values wrapped
            // in a single tuple. The second condition handles the
            // latter case.
            if (!output.isTuple() || p_node->num_outputs() == 1) {
              p_node->Output(0) = std::move(output);
              return;
            }
            auto& elems = output.toTupleRef().elements();
            TORCH_DCHECK_EQ(elems.size(), p_node->num_outputs());
            for (const auto i : c10::irange(elems.size())) {
              p_node->Output(i) = elems[i];
            }
          };
        case BlockRunPlan::kRunOnlyTrueBlock:
          return [](ProcessedNode* p_node) {
            auto condition = p_node->Input(0).toBool();
            auto* metadata = p_node->metadata();
            DCHECK(metadata);
            auto& block_runners = metadata->block_runners();
            TORCH_DCHECK_EQ(block_runners.size(), 2);
            if (condition) {
              auto output = block_runners.front()({});
              DCHECK(output.isNone());
            }
          };
        case BlockRunPlan::kRunOnlyFalseBlock:
          return [](ProcessedNode* p_node) {
            auto condition = p_node->Input(0).toBool();
            auto* metadata = p_node->metadata();
            DCHECK(metadata);
            auto& block_runners = metadata->block_runners();
            TORCH_DCHECK_EQ(block_runners.size(), 2);
            if (!condition) {
              auto output = block_runners.back()({});
              DCHECK(output.isNone());
            }
          };
        case BlockRunPlan::kRunNeitherBlock:
          return [](ProcessedNode*) {};
      }
      return [](ProcessedNode*) {};
    });

namespace {

std::vector<IValue> collectLoopSubBlockInputs(const ProcessedNode& p_node) {
  const auto num_inputs = p_node.num_inputs();
  TORCH_DCHECK_GE(num_inputs, 2);
  // The first two inputs to the loop node are the max trip count
  // and initial condition. We don't collect them here, since those
  // are not inputs for the sub-block.
  const auto num_args = num_inputs - 2;

  std::vector<IValue> result;
  result.reserve(num_args + 1);
  // First argument to the loop sub-block is always the loop counter, initially
  // zero.
  result.emplace_back(0);

  for (const auto i : c10::irange(num_args)) {
    result.push_back(p_node.Input(2 + i));
  }

  return result;
}

} // namespace

namespace {
/*
  ForkedSubgraphSRLauncher is responsible for executing a forked subgraph
  on a new instance of static runtime. Once the execution is completed,
  the future is marked as complete to indicate to aten::wait() that it
  can proceed.
*/
class TORCH_API ForkedSubgraphSRLauncher {
 public:
  ForkedSubgraphSRLauncher(
      std::shared_ptr<StaticModule> smodule,
      std::vector<IValue> args,
      c10::intrusive_ptr<Future> future,
      TaskLauncher launcher)
      : smodule_(std::move(smodule)),
        args_(std::move(args)),
        future_(std::move(future)),
        launcher_(std::move(launcher)) {}

  void operator()() {
    try {
      StaticRuntime runtime(*smodule_);
      auto future_subgraph = runtime.runAsync(args_, {}, launcher_);
      future_subgraph->waitAndThrow();
      future_->markCompleted(future_subgraph->value());
    } catch (const std::exception& e) {
      future_->setErrorIfNeeded(
          std::make_exception_ptr(c10::ivalue::Future::FutureError(e.what())));
    }
  }

 private:
  std::shared_ptr<StaticModule> smodule_;
  std::vector<IValue> args_;
  c10::intrusive_ptr<Future> future_;
  torch::jit::TaskLauncher launcher_;
};

/*
  Helper function to create a future matching the return type of the
  graph outputs. This function is utilized by the prim::fork and
  aten::wait operations for async execution of subgraphs.
*/
c10::intrusive_ptr<Future> createFutureTypeFromGraphOutput(
    const std::shared_ptr<torch::jit::Graph>& graph) {
  TypePtr return_type_;
  if (graph->outputs().size() == 1) {
    return_type_ = graph->outputs().at(0)->type();
  } else {
    return_type_ = TupleType::create(
        fmap(graph->outputs(), [](const Value* v) { return v->type(); }));
  }
  c10::intrusive_ptr<Future> future = c10::make_intrusive<Future>(return_type_);
  return future;
}
} // namespace

/*
  prim::fork forks the execution of a subgraph. It returns a future on which
  the corresponding aten::wait op waits until the future is marked complete.
  The current implementation creates an instance of StaticModule and uses it
  to create StaticRuntime instances on the fly during runtime to handle the
  execution of the forked subgraph. Async execution is handled by the
  aten::ParallelThreadPoolNative threadpool.
*/
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::fork,
    prim_Fork,
    [](Node* node) -> SROperator {
      if (!sr_schema_check_kind(node, prim::fork)) {
        return nullptr;
      }
      auto forkedGraph = node->g(attr::Subgraph);
      Inline(*forkedGraph);
      auto sr_metadata = node->ival(getStaticRuntimeMetadataSymbol())
                             .toCustomClass<StaticRuntimeMetadata>();
      auto smodule =
          std::make_shared<StaticModule>(forkedGraph, sr_metadata->get_opts());

      return [forkedGraph = std::move(forkedGraph),
              smodule = std::move(smodule)](ProcessedNode* p_node) {
        std::vector<IValue> args;
        args.reserve(p_node->num_inputs());
        for (const auto i : c10::irange(p_node->num_inputs())) {
          args.push_back(p_node->Input(i));
        }

        c10::intrusive_ptr<Future> future =
            createFutureTypeFromGraphOutput(forkedGraph);
        p_node->Output(0) = future;

        auto* metadata = p_node->metadata();
        DCHECK(metadata);
        auto* launcher = metadata->launcher();
        DCHECK(launcher);
        ForkedSubgraphSRLauncher runtime_launcher(
            smodule, args, future, *launcher);
        (*launcher)(std::move(runtime_launcher));
      };
    });
/*
  aten::wait waits on the future (produced by the corresponding fork)
  to be completed. Once the execution is complete, the future is marked
  completed and execution continues past the wait.
*/
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::wait,
    aten_Wait,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "aten::wait(Future(t) self) -> t")) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        TORCH_INTERNAL_ASSERT(p_node->Input(0).isFuture());
        auto future = p_node->Input(0).toFuture();

        // blocking call: waiting for the future to be completed
        future->waitAndThrow();

        TORCH_INTERNAL_ASSERT(future->completed());
        TORCH_INTERNAL_ASSERT(!future->hasError());
        TORCH_INTERNAL_ASSERT(future->hasValue());

        if (!future->value().isTuple()) {
          p_node->Output(0) = future->value();
          return;
        }
        auto& elems = future->value().toTupleRef().elements();
        TORCH_DCHECK_EQ(elems.size(), p_node->num_outputs());
        for (const auto i : c10::irange(elems.size())) {
          p_node->Output(i) = elems[i];
        }
      };
    });

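// Illustrative TorchScript-level usage that lowers to the prim::fork /
// aten::wait pair handled above (sketch for orientation only; `submodule`
// and `x` are hypothetical names):
//
//   fut = torch.jit.fork(submodule, x)
//   ...  # other work runs concurrently with the forked subgraph
//   y = torch.jit.wait(fut)
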
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::Loop,
    prim_Loop,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::Loop)) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto max_trip_count = p_node->Input(0).toInt();
        auto condition = p_node->Input(1).toBool();

        auto* metadata = p_node->metadata();
        DCHECK(metadata);
        auto& block_runners = metadata->block_runners();
        TORCH_DCHECK_EQ(block_runners.size(), 1);
        auto& runner = block_runners[0];

        auto args = collectLoopSubBlockInputs(*p_node);
        int64_t loop_count = 0;

        while (condition && loop_count < max_trip_count) {
          auto output = runner(args);

          if (output.isTuple()) {
            auto& elems = output.toTupleRef().elements();
            DCHECK(elems.size() == args.size());
            for (const auto i : c10::irange(1, args.size())) {
              args[i] = elems[i];
            }
            condition = elems[0].toBool();
          } else {
            condition = output.toBool();
          }
          args[0] = ++loop_count;
        }

        const auto num_outputs = p_node->num_outputs();
        TORCH_DCHECK_EQ(args.size(), num_outputs + 1);
        for (const auto i : c10::irange(num_outputs)) {
          p_node->Output(i) = std::move(args[i + 1]);
        }
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::CreateObject,
    prim_CreateObject,
    [](Node* node) -> SROperator {
      if (!sr_schema_check_kind(node, prim::CreateObject)) {
        return nullptr;
      }
      auto class_type = node->output()->type()->expect<ClassType>();
      return [class_type = std::move(class_type)](ProcessedNode* pnode) {
        pnode->Output(0) = c10::ivalue::Object::create(
            c10::StrongTypePtr(class_type->compilation_unit(), class_type),
            class_type->numAttributes());
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::TupleIndex,
    prim_TupleIndex,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::TupleIndex)) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& elems = pnode->Input(0).toTupleRef().elements();
        using c10::ssize;
        const auto num_elems = ssize(elems);
        const auto idx = pnode->Input(1).toInt();
        const auto norm_idx = normalizeIndex(idx, num_elems);
        if (norm_idx < 0 || norm_idx >= num_elems) {
          // Use std::runtime_error instead of c10::Error to be consistent with
          // JIT
          throw std::out_of_range("Tuple index out of range");
        }
        pnode->Output(0) = elems[norm_idx];
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::RaiseException,
    prim_RaiseException,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::RaiseException)) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& message = pnode->Input(0).toStringRef();
        throw std::runtime_error(message);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::Uninitialized,
    prim_Uninitialized,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::Uninitialized)) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        pnode->Output(0) = IValue::uninitialized();
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::format,
    aten_format,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "aten::format(str self, ...) -> str")) {
        return nullptr;
      }
      TORCH_CHECK(!n->inputs().empty());
      return [](ProcessedNode* pnode) {
        const auto num_inputs = pnode->num_inputs();
        auto stack = boxInputs(*pnode);
        format(stack, num_inputs);
        TORCH_DCHECK_EQ(stack.size(), 1);
        pnode->Output(0) = std::move(stack[0]);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::device,
    prim_device,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "prim::device(Tensor a) -> Device")) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& input = pnode->Input(0).toTensor();
        pnode->Output(0) = input.device();
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::dtype,
    prim_dtype,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::dtype)) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& input = pnode->Input(0).toTensor();
        pnode->Output(0) = static_cast<int64_t>(input.scalar_type());
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::dim,
    aten_dim,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "aten::dim(Tensor self) -> int")) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& input = pnode->Input(0).toTensor();
        pnode->Output(0) = input.dim();
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::__not__,
    aten_not,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "aten::__not__(bool self) -> bool")) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        auto input = pnode->Input(0).toBool();
        pnode->Output(0) = !input;
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::Bool,
    aten_Bool,
    [](Node* n) -> SROperator {
      if (n->matches(torch::schema("aten::Bool.Tensor(Tensor a) -> bool"))) {
        return [](ProcessedNode* pnode) {
          const auto& input = pnode->Input(0).toTensor();
          pnode->Output(0) = at::native::is_nonzero(input);
        };
      }
      if (n->matches(torch::schema("aten::Bool.int(int a) -> bool"))) {
        return [](ProcessedNode* pnode) {
          const auto input = pnode->Input(0).toInt();
          pnode->Output(0) = static_cast<bool>(input);
        };
      }
      if (n->matches(torch::schema("aten::Bool.float(float a) -> bool"))) {
        return [](ProcessedNode* pnode) {
          const auto input = pnode->Input(0).toDouble();
          pnode->Output(0) = static_cast<bool>(input);
        };
      }
      LogAndDumpSchema(n);
      return nullptr;
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::is_cuda,
    prim_is_cuda,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "prim::is_cuda(Tensor a) -> bool")) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& input = pnode->Input(0).toTensor();
        pnode->Output(0) = input.is_cuda();
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::tolist,
    prim_tolist,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::tolist)) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& input = pnode->Input(0).toTensor();
        const auto dim = pnode->Input(1).toInt();
        const auto elem_type = pnode->Input(2).toInt();
        std::vector<IValue> stack{input, dim, elem_type};
        toList(stack);
        TORCH_DCHECK_EQ(stack.size(), 1);
        pnode->Output(0) = std::move(stack[0]);
      };
    });

// See [Borrowed IValue Outputs]
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::IfThenElse,
    prim_IfThenElse,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::IfThenElse)) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto condition = pnode->Input(0).toBool();
        pnode->Output(0) = condition ? createBorrowedIValue(pnode->Input(1))
                                     : createBorrowedIValue(pnode->Input(2));
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::len,
    aten_len,
    [](Node* n) -> SROperator {
      if (n->matches(torch::schema("aten::len.t(t[] a) -> int")) ||
          n->matches(torch::schema("aten::len.any(Any[] a) -> int"))) {
        return [](ProcessedNode* pnode) {
          const auto list = pnode->Input(0).toListRef();
          const int64_t size = list.size();
          pnode->Output(0) = size;
        };
      }
      if (n->matches(torch::schema("aten::len.Tensor(Tensor t) -> int"))) {
        return [](ProcessedNode* pnode) {
          const auto& t = pnode->Input(0).toTensor();
          TORCH_CHECK(t.dim() > 0);
          pnode->Output(0) = t.sizes()[0];
        };
      }
      if (n->matches(torch::schema("aten::len.str(str s) -> int"))) {
        return [](ProcessedNode* pnode) {
          const auto& string = pnode->Input(0).toStringRef();
          pnode->Output(0) = static_cast<int64_t>(string.size());
        };
      }
      if (n->matches(
              torch::schema("aten::len.Dict_str(Dict(str, t) self) -> int")) ||
          n->matches(
              torch::schema("aten::len.Dict_int(Dict(int, t) self) -> int")) ||
          n->matches(torch::schema(
              "aten::len.Dict_bool(Dict(bool, t) self) -> int")) ||
          n->matches(torch::schema(
              "aten::len.Dict_float(Dict(float, t) self) -> int")) ||
          n->matches(torch::schema(
              "aten::len.Dict_complex(Dict(complex, t) self) -> int")) ||
          n->matches(torch::schema(
              "aten::len.Dict_Tensor(Dict(Tensor, t) self) -> int"))) {
        return [](ProcessedNode* pnode) {
          const auto& dict = pnode->Input(0).toGenericDict();
          pnode->Output(0) = static_cast<int64_t>(dict.size());
        };
      }
      LogAndDumpSchema(n);
      return nullptr;
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::IntImplicit,
    aten_IntImplicit,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema("aten::IntImplicit(Tensor a) -> int"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& tensor = pnode->Input(0).toTensor();
        // JIT does a check for requires_grad, but we skip it here since SR is
        // inference only
        if (!tensor.sizes().empty()) {
          throw std::runtime_error(
              "Cannot convert a tensor of dimension > 0 to scalar");
        }
        if (!isIntegralType(tensor.scalar_type(), /*includeBool=*/false)) {
          std::stringstream ss;
          ss << "Cannot input a tensor of type " << tensor.scalar_type()
             << " as an integral argument";
          throw std::runtime_error(ss.str());
        }
        pnode->Output(0) = at::native::item(tensor).toInt();
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::select,
    aten_select,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::select(Tensor(a) self, int dim, int index) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& self = pnode->Input(0).toTensor();
        const auto dim = pnode->Input(1).toInt();
        const auto index = pnode->Input(2).toInt();
        pnode->Output(0) = at::native::select(self, dim, index);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::reshape_as,
    aten_reshape_as,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::reshape_as(Tensor(a) self, Tensor other) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& self = pnode->Input(0).toTensor();
        const auto& other = pnode->Input(1).toTensor();
        pnode->Output(0) = at::native::reshape(self, other.sizes());
      };
    });

} // namespace torch::jit