#include <torch/csrc/jit/runtime/static/ops.h>

#include <ATen/CPUFunctions.h>
#include <ATen/InferSize.h>
#include <ATen/NativeFunctions.h>
#include <ATen/ScalarOps.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/IndexingUtils.h>
#include <ATen/native/NonSymbolicBC.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TensorAdvancedIndexing.h>
#include <c10/util/intrusive_ptr.h>
#include <c10/util/irange.h>
#include <c10/util/ssize.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/csrc/jit/runtime/static/processed_node_wrapper.h>
#include <torch/csrc/jit/runtime/vararg_functions.h>

namespace {
constexpr auto createBorrowedIValue =
    c10::MaybeOwnedTraits<c10::IValue>::createBorrow;
} // namespace

namespace torch::jit {

namespace {

std::vector<IValue> boxInputs(const ProcessedNode& pnode) {
  std::vector<IValue> result;
  for (const auto i : c10::irange(pnode.num_inputs())) {
    result.push_back(pnode.Input(i));
  }
  return result;
}

} // namespace

C10_DEFINE_REGISTRY(SRNativeOperatorRegistry, SROperatorFunctor)

bool nativeOpIsRegistered(const c10::Symbol& op_name) {
  const std::string name(op_name.toQualString());
  return SRNativeOperatorRegistry()->Has(name);
}

SROperator getNativeOperation(Node* n) {
  auto op_name = n->kind().toQualString();
  if (SRNativeOperatorRegistry()->Has(op_name)) {
    return SRNativeOperatorRegistry()->Create(op_name)->Generate(n);
  }
  return nullptr;
}

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::TupleConstruct,
    prim_TupleConstruct,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::TupleConstruct)) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        // prepare inputs
        auto stack = boxInputs(*p_node);
        // run op
        auto* node = p_node->node();
        const auto& type = node->output()->type()->expect<TupleType>();
        if (type->name().has_value()) {
          namedTupleConstruct(stack, type, node->inputs().size());
        } else {
          tupleConstruct(stack, node->inputs().size());
        }
        // put output back
        p_node->Output(0) = std::move(stack[0]);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::TupleUnpack,
    prim_TupleUnpack,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::TupleUnpack)) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& elems = p_node->Input(0).toTupleRef().elements();
        const size_t num_outputs = p_node->outputs().size();
        TORCH_CHECK(
            num_outputs == elems.size(),
            "Number of outputs must match number of tuple elements.");
        for (size_t i = 0; i < num_outputs; ++i) {
          p_node->Output(i) = elems[i];
        }
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::DictConstruct,
    prim_DictConstruct,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::DictConstruct)) {
        return nullptr;
      }
      auto dict_type = n->output()->type()->expect<DictType>();
      const auto num_inputs = n->inputs().size();
      TORCH_DCHECK_EQ(num_inputs % 2, 0);
      return [dict_type = std::move(dict_type),
              num_inputs,
              dict_size = num_inputs / 2](ProcessedNode* p_node) {
        auto result = c10::impl::GenericDict(
            dict_type->containedType(0), dict_type->containedType(1));
        result.reserve(dict_size);
        for (size_t i = 0; i < num_inputs; i += 2) {
          const auto& key = p_node->Input(i);
          const auto& value = p_node->Input(i + 1);
          result.insert_or_assign(key, value);
        }
        p_node->Output(0) = result;
      };
    });
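// Illustrative note (not from the original source): for the TorchScript
// expression `d = {"k1": v1, "k2": v2}`, prim::DictConstruct receives its
// inputs interleaved as key/value pairs:
//   Input(0) == "k1", Input(1) == v1, Input(2) == "k2", Input(3) == v2
// which is why the loop above advances by two and inserts
// (Input(i), Input(i + 1)) into the result dict.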
// See [Borrowed IValue Outputs]
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    static_runtime::dict_unpack,
    static_runtime_dict_unpack,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "static_runtime::dict_unpack(...) -> ...")) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        DCHECK(
            static_cast<size_t>(p_node->num_inputs() - 1) ==
            p_node->outputs().size());
        auto dict = p_node->Input(0).toGenericDict();
        const auto num_inputs = p_node->num_inputs();
        for (size_t i = 1; i < num_inputs; ++i) {
          const auto& key = p_node->Input(i);
          auto value = dict.find(key);
          TORCH_CHECK(value != dict.end(), "Key not in dict: ", key);
          p_node->Output(i - 1) = createBorrowedIValue(value->value());
        }
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::__getitem__,
    aten_getitem,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(
              n,
              // TODO: "aten::__getitem__.str(str s, int index) -> str",
              "aten::__getitem__.t(t[](a) list, int idx) -> t(*)",
              "aten::__getitem__.Dict_str(Dict(str, t) self, str key) -> t(*)",
              "aten::__getitem__.Dict_int(Dict(int, t) self, int key) -> t(*)",
              "aten::__getitem__.Dict_bool(Dict(bool, t) self, bool key) -> t(*)",
              "aten::__getitem__.Dict_float(Dict(float, t) self, float key) -> t(*)",
              "aten::__getitem__.Dict_complex(Dict(complex, t) self, complex key) -> t(*)",
              "aten::__getitem__.Dict_Tensor(Dict(Tensor, t) self, Tensor key) -> t(*)")) {
        return nullptr;
      }
      if (n->inputs().size() != 2) {
        return nullptr;
      }

      if (n->input(0)->type()->castRaw<DictType>()) {
        return [](ProcessedNode* p_node) {
          auto dict = p_node->Input(0).toGenericDict();
          const auto& key = p_node->Input(1);
          auto value = dict.find(key);
          TORCH_CHECK(value != dict.end(), "Key not in dict: ", key);
          p_node->Output(0) = value->value();
        };
      } else if (n->input(0)->type()->castRaw<ListType>()) {
        return [](ProcessedNode* p_node) {
          const auto& list = p_node->Input(0).toList();
          auto idx = p_node->Input(1).toInt();
          p_node->Output(0) = getItem(list, idx);
        };
      }

      // TODO(T98581096): make __getitem__ work for other container types
      return nullptr;
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::ListConstruct,
    prim_ListConstruct,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::ListConstruct)) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        // prepare inputs
        auto stack = boxInputs(*p_node);
        // run op
        listConstruct(
            stack,
            p_node->node()->output()->type()->expectRef<ListType>(),
            p_node->num_inputs());
        // put output back
        p_node->Output(0) = std::move(stack[0]);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::ListUnpack,
    prim_ListUnpack,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::ListUnpack)) {
        return nullptr;
      }
      const auto num_outputs = n->outputs().size();
      return [num_outputs](ProcessedNode* p_node) {
        const auto list = p_node->Input(0).toListRef();
        TORCH_CHECK(
            list.size() == num_outputs,
            "Expected ",
            num_outputs,
            " elements in list but got ",
            list.size());
        for (const auto i : c10::irange(num_outputs)) {
          p_node->Output(i) = list[i];
        }
      };
    });
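// Illustrative note (not from the original source): num_outputs is fixed by
// the graph, so a TorchScript unpacking such as `a, b = lst` compiles to a
// prim::ListUnpack with two outputs and fails the TORCH_CHECK above at run
// time unless len(lst) == 2. Each output slot receives an owned copy of the
// corresponding element IValue.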
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::append,
    aten_append,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(
              n, "aten::append.t(t[](a!) self, t(c -> *) el) -> t[](a!)")) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        auto list = p_node->Input(0).toList();
        list.push_back(p_node->Input(1));
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::list,
    aten_list,
    [](Node* n) -> SROperator {
      if (n->matches(torch::schema("aten::list(str t) -> str[]"))) {
        return [](ProcessedNode* p_node) {
          const auto str = p_node->Input(0).toStringRef();
          c10::List<std::string> chars;
          chars.reserve(str.size());
          for (auto c : str) {
            chars.emplace_back(1, c);
          }
          p_node->Output(0) = std::move(chars);
        };
      }

      if (n->matches(torch::schema("aten::list.t(t[] l) -> t[]"))) {
        return [](ProcessedNode* p_node) {
          const auto input = p_node->Input(0).toList();
          p_node->Output(0) = input.copy();
        };
      }

      LogAndDumpSchema(n);
      return nullptr;
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::numel,
    aten_numel,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "aten::numel(Tensor self) -> int")) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& arg = p_node->Input(0).toTensor();
        p_node->Output(0) = arg.numel();
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::cpu,
    aten_cpu,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "aten::cpu(Tensor self) -> Tensor")) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& arg = p_node->Input(0).toTensor();
        p_node->Output(0) = arg.cpu();
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::__range_length,
    aten_range_length,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(
              n, "aten::__range_length(int lo, int hi, int step) -> int")) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        auto lo = p_node->Input(0).toInt();
        auto hi = p_node->Input(1).toInt();
        auto step = p_node->Input(2).toInt();
        // error handling when step_val == 0 during runtime
        if (step == 0) {
          throw std::runtime_error("range() arg 3 must not be zero");
        }
        if (step > 0 && lo < hi) {
          p_node->Output(0) = 1 + (hi - 1 - lo) / step;
        } else if (step < 0 && lo > hi) {
          p_node->Output(0) = 1 + (lo - 1 - hi) / (0 - step);
        } else {
          p_node->Output(0) = 0;
        }
      };
    });
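// Worked example (illustrative): range(0, 10, 3) yields 0, 3, 6, 9, so the
// positive-step formula gives 1 + (10 - 1 - 0) / 3 == 4. For a negative step,
// range(10, 0, -3) yields 10, 7, 4, 1, and 1 + (10 - 1 - 0) / 3 == 4 as well.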
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::index_put,
    aten_index_put,
    [](Node* n) -> SROperator {
      if (n->matches(torch::schema(
              "aten::index_put(Tensor self, Tensor[] indices, Tensor values, bool accumulate=False) -> Tensor")) ||
          n->matches(torch::schema(
              "aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor"))) {
        return [](ProcessedNode* p_node) {
          const auto& self = p_node->Input(0).toTensor();
          const auto& indices =
              at::native::toListOfOptionalTensors(p_node->Input(1).toListRef());
          const auto& values = p_node->Input(2).toTensor();
          const auto accumulate = p_node->Input(3).toBool();
          p_node->Output(0) =
              at::native::index_put(self, indices, values, accumulate);
        };
      }

      LogAndDumpSchema(n);
      return nullptr;
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::item,
    aten_item,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "aten::item(Tensor self) -> Scalar")) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& self = p_node->Input(0).toTensor();
        p_node->Output(0) = at::native::item(self);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::GetAttr,
    prim_GetAttr,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::GetAttr)) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        auto& module = p_node->Input(0).toObjectRef();
        Node* node = p_node->node();
        const auto& type = node->input()->type()->expectRef<ClassType>();
        const auto& field = node->s(attr::name);
        const auto slot = type.getAttributeSlot(field);
        p_node->Output(0) = module.getSlot(slot);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::SetAttr,
    prim_SetAttr,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::SetAttr)) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        auto& module = p_node->Input(0).toObjectRef();
        Node* node = p_node->node();
        const auto& type = node->inputs()[0]->type()->expectRef<ClassType>();
        const auto& field = node->s(attr::name);
        const auto slot = type.getAttributeSlot(field);
        module.setSlot(slot, p_node->Input(1));
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::transpose,
    aten_transpose,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& in0_t = p_node->Input(0).toTensor();
        const auto in1_i = p_node->Input(1).toInt();
        const auto in2_i = p_node->Input(2).toInt();
        p_node->Output(0) = at::native::transpose(in0_t, in1_i, in2_i);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::flatten,
    aten_flatten,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& in0_t = p_node->Input(0).toTensor();
        const auto in1_i = p_node->Input(1).toInt();
        const auto in2_i = p_node->Input(2).toInt();
        p_node->Output(0) = at::native::flatten(in0_t, in1_i, in2_i);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::permute,
    aten_permute,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& in0_t = p_node->Input(0).toTensor();
        const auto in1_iv = p_node->Input(1).toDimVector();
        p_node->Output(0) = at::native::permute(in0_t, in1_iv);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::reshape,
    aten_reshape,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::reshape(Tensor(a) self, int[] shape) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& in0_t = p_node->Input(0).toTensor();
        const auto in1_iv = p_node->Input(1).toDimVector();
        p_node->Output(0) = at::native::reshape(in0_t, in1_iv);
      };
    });
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::slice,
    aten_slice,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=0, int? end=9223372036854775807, int step=1) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& in0_t = p_node->Input(0).toTensor();
        const auto in1_i = p_node->Input(1).toInt();
        const auto in2_i = p_node->Input(2).toOptional<int64_t>();
        const auto in3_i = p_node->Input(3).toOptional<int64_t>();
        const auto in4_i = p_node->Input(4).toInt();
        p_node->Output(0) =
            at::native::slice(in0_t, in1_i, in2_i, in3_i, in4_i);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::narrow,
    aten_narrow,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::narrow(Tensor(a) self, int dim, int start, int length) -> Tensor(a)")) &&
          !n->matches(torch::schema(
              "aten::narrow.Tensor(Tensor(a) self, int dim, Tensor start, int length) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& self = p_node->Input(0).toTensor(); // self
        const auto dim = p_node->Input(1).toInt(); // dim
        int64_t start = 0;
        if (p_node->Input(2).isScalar()) {
          start = p_node->Input(2).toInt();
        } else {
          auto& t = p_node->Input(2).toTensor();
          start = t.item<int64_t>();
        }
        const auto length = p_node->Input(3).toInt(); // length
        TORCH_CHECK(
            self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor.");
        auto cur_size = self.sizes()[dim];
        // start being the end is valid, but not a valid dim specification.
        if (start != cur_size && start < 0) {
          start = at::maybe_wrap_dim(start, cur_size);
        }
        TORCH_CHECK(
            length >= 0 && start <= cur_size - length,
            "start (",
            start,
            ") + length (",
            length,
            ") exceeds dimension size (",
            cur_size,
            ").");
        p_node->Output(0) =
            at::native::slice(self, dim, start, start + length, 1);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::to, aten_to, [](Node* n) -> SROperator {
  if (n->matches(torch::schema(
          "aten::to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)"))) {
    return [](ProcessedNode* p_node) {
      const auto& in0_t = p_node->Input(0).toTensor();
      const auto& in1_t = p_node->Input(1).toTensor();
      const auto in2_i = p_node->Input(2).toBool();
      const auto in3_i = p_node->Input(3).toBool();
      const auto in4_o = p_node->Input(4).toOptional<at::MemoryFormat>();
      p_node->Output(0) = at::native::to(in0_t, in1_t, in2_i, in3_i, in4_o);
    };
  }
  if (n->matches(torch::schema(
          "aten::to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)"))) {
    return [](ProcessedNode* p_node) {
      const auto& in0_t = p_node->Input(0).toTensor();
      const auto in1_i = p_node->Input(1).toScalarType();
      const auto in2_i = p_node->Input(2).toBool();
      const auto in3_i = p_node->Input(3).toBool();
      const auto in4_o = p_node->Input(4).toOptional<at::MemoryFormat>();
      p_node->Output(0) = at::native::to(in0_t, in1_i, in2_i, in3_i, in4_o);
    };
  }
  if (n->matches(torch::schema(
          "aten::to.prim_dtype(Tensor(a) self, int? dtype, bool non_blocking=False, bool copy=False) -> Tensor(a|b)"))) {
    return [](ProcessedNode* p_node) {
      const auto& in0_t = p_node->Input(0).toTensor();
      const auto in1_i = p_node->Input(1).toOptional<at::ScalarType>();
      const auto in2_i = p_node->Input(2).toBool();
      const auto in3_i = p_node->Input(3).toBool();
      // To mimic the behavior of the JIT interpreter, if both dtype
      // and copy are not set, we return self. Otherwise, we assume
      // that dtype is set.
      if (!in1_i && !in3_i) {
        p_node->Output(0) = in0_t;
      } else {
        TORCH_CHECK(
            in1_i,
            "dtype cannot be None when copy is True for aten::to.prim_dtype");
        p_node->Output(0) = at::native::to(in0_t, *in1_i, in2_i, in3_i);
      }
    };
  }
  LogAndDumpSchema(n);
  return nullptr;
});
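// Illustrative note (not from the original source): for
// `x.to(dtype=None, non_blocking=False, copy=False)`, both in1_i and in3_i
// are falsy, so the prim_dtype overload above forwards `self` unchanged,
// matching the JIT interpreter. Any other combination is treated as a real
// conversion via at::native::to, and dtype=None with copy=True is rejected.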
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::detach,
    aten_detach,
    [](Node* n) -> SROperator {
      if (!n->matches(
              torch::schema("aten::detach(Tensor(a) self) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& in0_t = p_node->Input(0).toTensor();
        p_node->Output(0) = at::native::alias(in0_t);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::expand_as,
    aten_expand_as,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& self = p_node->Input(0).toTensor();
        const auto& other = p_node->Input(1).toTensor();
        p_node->Output(0) = self.expand(other.sizes());
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::isinstance,
    prim_isinstance,
    [](Node* n) -> SROperator {
      if (!n->matches(
              torch::schema("prim::isinstance(Any to_check) -> bool"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        auto input_type = p_node->Input(0).type();
        auto* node = p_node->node();
        const std::vector<TypePtr>& candidates = node->tys(attr::types);
        for (const auto& candidate_type : candidates) {
          if (input_type->isSubtypeOf(*candidate_type)) {
            p_node->Output(0) = true;
            return;
          }
        }
        p_node->Output(0) = false;
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::TypeCheck,
    prim_TypeCheck,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::TypeCheck)) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        auto* node = p_node->node();
        const size_t num_inputs = node->inputs().size();
        TORCH_INTERNAL_ASSERT(
            num_inputs && num_inputs + 1 == node->outputs().size());

        const auto& expected_types = node->tys(attr::types);

        for (size_t i = 0; i < num_inputs; i++) {
          p_node->Output(i) = p_node->Input(i);
        }

        for (size_t i = 0; i < num_inputs; i++) {
          auto& input_tensor = p_node->Input(i).toTensor();
          auto* expected_type = expected_types[i]->castRaw<TensorType>();
          if (input_tensor.defined() &&
              !expected_type->matchTensor(input_tensor)) {
            p_node->Output(num_inputs) = false;
            return;
          }
        }

        p_node->Output(num_inputs) = true;
      };
    });
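// Illustrative note (not from the original source): prim::TypeCheck with N
// inputs has N + 1 outputs. Outputs 0..N-1 forward the inputs unchanged, and
// Output(N) is a single bool that is true only if every defined input tensor
// matches its expected TensorType, which is what the two loops above compute.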
// See [Borrowed IValue Outputs]
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    static_runtime::VarTupleUnpack,
    static_runtime_VarTupleUnpack,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "static_runtime::VarTupleUnpack(...) -> ...")) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        size_t output_idx = 0;
        for (const auto idx : c10::irange(pnode->num_inputs())) {
          const auto& tuple = pnode->Input(idx);
          for (auto& elem : tuple.toTupleRef().elements()) {
            pnode->Output(output_idx) = createBorrowedIValue(elem);
            ++output_idx;
          }
        }
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::view,
    aten_view,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::view(Tensor(a) self, int[] size) -> (Tensor(a))"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& input = p_node->Input(0).toTensor();
        const auto size = p_node->Input(1).toIntList();
        p_node->Output(0) = at::native::view(input, size.vec());
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::size,
    aten_size,
    [](Node* n) -> SROperator {
      if (n->matches(
              torch::schema("aten::size(Tensor self, int dim) -> int"))) {
        return [](ProcessedNode* p_node) {
          const auto& input = p_node->Input(0).toTensor();
          auto dim = p_node->Input(1).toInt();
          const auto ndim = input.dim();

          if (dim < 0 || dim >= ndim) {
            dim = c10::maybe_wrap_dim(dim, ndim);
          }
          p_node->Output(0) = input.sizes()[dim];
        };
      }
      if (n->matches(torch::schema("aten::size(Tensor self) -> int[]"))) {
        return [](ProcessedNode* p_node) {
          const auto& input = p_node->Input(0).toTensor();
          p_node->Output(0) = input.sizes();
        };
      }
      LogAndDumpSchema(n);
      return nullptr;
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::squeeze,
    aten_squeeze,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& self = p_node->Input(0).toTensor();
        const auto dim = p_node->Input(1).toInt();
        p_node->Output(0) = at::native::squeeze(self, dim);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::split,
    aten_split,
    [](Node* n) -> SROperator {
      if (n->matches(torch::schema(
              "aten::split(Tensor(a -> *) self, int split_size, int dim=0) -> Tensor(a)[]"))) {
        return [](ProcessedNode* p_node) {
          const auto& self = p_node->Input(0).toTensor();
          const auto split_size = p_node->Input(1).toInt();
          const auto dim = p_node->Input(2).toInt();
          p_node->Output(0) = at::native::split(self, split_size, dim);
        };
      }

      if (n->matches(torch::schema(
              "aten::split(Tensor(a -> *) self, int[] split_sizes, int dim=0) -> (Tensor[])"))) {
        return [](ProcessedNode* p_node) {
          const auto& self = p_node->Input(0).toTensor();
          const auto& split_sizes = p_node->Input(1).toIntList();
          const auto dim = p_node->Input(2).toInt();
          p_node->Output(0) =
              at::native::split_with_sizes(self, split_sizes.vec(), dim);
        };
      }

      LogAndDumpSchema(n);
      return nullptr;
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::split_with_sizes,
    aten_split_with_sizes,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::split_with_sizes(Tensor(a -> *) self, int[] split_sizes, int dim=0) -> Tensor(a)[]")) &&
          !n->matches(torch::schema(
              "aten::split_with_sizes(Tensor(a -> *) self, int[] split_sizes, int dim=0) -> (Tensor[])"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto& self = p_node->Input(0).toTensor();
        const auto& split_sizes = p_node->Input(1).toIntList();
        const auto dim = p_node->Input(2).toInt();
        p_node->Output(0) =
            at::native::split_with_sizes(self, split_sizes.vec(), dim);
      };
    });
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    static_runtime::select_tensor,
    aten_select_tensor,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(
              n,
              "static_runtime::select_tensor(Tensor(a) a, Tensor(b) b, bool use_b) -> Tensor(a|b)")) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto did_copy = p_node->Input(2).toBool();
        DCHECK(p_node->Input(0).isTensor());
        DCHECK(!did_copy || p_node->Input(1).isTensor());
        const IValue& assignFrom =
            did_copy ? p_node->Input(1) : p_node->Input(0);
        // Create an IValue that borrows the input Tensor in order to
        // save a refcount increment here and decrement in
        // MemoryPlanner::deallocate. MemoryPlanner knows about this
        // and will safely clean it up by using the corresponding
        // destroyBorrow method.
        TORCH_DCHECK_NE(&assignFrom, &p_node->Output(0));
        // MemoryPlanner should have cleaned this up!
        DCHECK(p_node->Output(0).isNone());
        p_node->Output(0) =
            IValue(c10::MaybeOwnedTraits<at::Tensor>::createBorrow(
                assignFrom.toTensor()));
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::mul,
    aten_mul,
    [](Node* n) -> SROperator {
      if (!n->matches(
              torch::schema("aten::mul.left_t(t[] l, int n) -> (t[])"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& list = pnode->Input(0).toList();
        const auto n = pnode->Input(1).toInt();
        auto list_type = list.elementType();
        auto ret = c10::impl::GenericList(list_type);
        ret.reserve(list.size() * n);
        for (const auto i : c10::irange(n)) {
          (void)i;
          for (const auto& ival : list) {
            ret.push_back(ival);
          }
        }
        pnode->Output(0) = ret;
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::sub,
    aten_sub,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema("aten::sub.int(int a, int b) -> (int)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto a = pnode->Input(0).toInt();
        const auto b = pnode->Input(1).toInt();
        pnode->Output(0) = a - b;
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::add,
    aten_add,
    [](Node* n) -> SROperator {
      if (n->matches(torch::schema("aten::add.t(t[] a, t[] b) -> (t[])"))) {
        return [](ProcessedNode* pnode) {
          const auto& a = pnode->Input(0).toList();
          const auto& b = pnode->Input(1).toList();
          auto ret = a.copy();
          ret.append(b);
          pnode->Output(0) = ret;
        };
      }
      if (n->matches(torch::schema("aten::add.int(int a, int b) -> (int)"))) {
        return [](ProcessedNode* pnode) {
          const auto a = pnode->Input(0).toInt();
          const auto b = pnode->Input(1).toInt();
          pnode->Output(0) = a + b;
        };
      }
      LogAndDumpSchema(n);
      return nullptr;
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::tensor_split,
    aten_tensor_split,
    [](Node* n) -> SROperator {
      if (n->matches(torch::schema(
              "aten::tensor_split.indices(Tensor(a -> *) self, int[] indices, int dim=0) -> Tensor(a)[]"))) {
        return [](ProcessedNode* pnode) {
          const auto& a = pnode->Input(0).toTensor();
          const auto& b = pnode->Input(1).toIntVector();
          const auto c = pnode->Input(2).toInt();
          pnode->Output(0) = at::native::tensor_split(a, b, c);
        };
      }
      if (n->matches(torch::schema(
              "aten::tensor_split.sections(Tensor(a -> *) self, int sections, int dim=0) -> Tensor(a)[]"))) {
        return [](ProcessedNode* pnode) {
          const auto& a = pnode->Input(0).toTensor();
          const auto b = pnode->Input(1).toSymInt();
          const auto c = pnode->Input(2).toInt();
          pnode->Output(0) = at::native::tensor_split_sections_symint(a, b, c);
        };
      }
      if (n->matches(torch::schema(
              "aten::tensor_split.tensor_indices_or_sections(Tensor(a -> *) self, Tensor tensor_indices_or_sections, int dim=0) -> Tensor(a)[]"))) {
        return [](ProcessedNode* pnode) {
          const auto& a = pnode->Input(0).toTensor();
          const auto& b = pnode->Input(1).toTensor();
          const auto c = pnode->Input(2).toInt();
          pnode->Output(0) = at::native::tensor_split(a, b, c);
        };
      }
      LogAndDumpSchema(n);
      return nullptr;
    });
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::Int,
    aten_Int,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema("aten::Int(Tensor a) -> int"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& input = pnode->Input(0).toTensor();
        pnode->Output(0) = at::native::item(input).toInt();
      };
    });

// See [Create owned refs for special values]
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    static_runtime::create_owned_ref,
    static_runtime_create_owned_ref,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "static_runtime::create_owned_ref(...) -> ...")) {
        return nullptr;
      }
      return
          [](ProcessedNode* p_node) { p_node->Output(0) = p_node->Input(0); };
    });

namespace {

bool outputsEmpty(const Block* block) {
  return block->outputs().size() == 1 && block->outputs().at(0)->mustBeNone();
}

bool blockEmpty(const Block* block) {
  return block->nodes().begin() == block->nodes().end();
}

enum class BlockRunPlan : int8_t {
  kRunOnlyTrueBlock,
  kRunOnlyFalseBlock,
  kRunBothBlocks,
  kRunNeitherBlock,
};

} // namespace

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::If,
    prim_If,
    [](Node* node) -> SROperator {
      if (!sr_schema_check_kind(node, prim::If)) {
        return nullptr;
      }
      TORCH_DCHECK_EQ(node->blocks().size(), 2);
      const Block* true_block = node->blocks().at(0);
      const Block* false_block = node->blocks().at(1);

      const bool true_block_returns_empty = outputsEmpty(true_block);
      const bool false_block_returns_empty = outputsEmpty(false_block);

      BlockRunPlan block_run_plan = BlockRunPlan::kRunNeitherBlock;

      if (true_block_returns_empty && false_block_returns_empty) {
        const bool false_block_is_empty = blockEmpty(false_block);
        const bool true_block_is_empty = blockEmpty(true_block);

        if (false_block_is_empty && !true_block_is_empty) {
          block_run_plan = BlockRunPlan::kRunOnlyTrueBlock;
        } else if (!false_block_is_empty && true_block_is_empty) {
          block_run_plan = BlockRunPlan::kRunOnlyFalseBlock;
        } else if (false_block_is_empty && true_block_is_empty) {
          block_run_plan = BlockRunPlan::kRunNeitherBlock;
        } else {
          block_run_plan = BlockRunPlan::kRunBothBlocks;
        }
      } else {
        block_run_plan = BlockRunPlan::kRunBothBlocks;
      }

      switch (block_run_plan) {
        case BlockRunPlan::kRunBothBlocks:
          return [](ProcessedNode* p_node) {
            auto condition = p_node->Input(0).toBool();
            auto* metadata = p_node->metadata();
            DCHECK(metadata);
            auto& block_runners = metadata->block_runners();
            TORCH_DCHECK_EQ(block_runners.size(), 2);
            auto& runner = block_runners[!condition];

            auto output = runner({});
            // If we are returning a tuple, we are either returning
            // multiple unpacked values or all of the values wrapped
            // in a single tuple. The second condition handles the
            // latter case.
            if (!output.isTuple() || p_node->num_outputs() == 1) {
              p_node->Output(0) = std::move(output);
              return;
            }
            auto& elems = output.toTupleRef().elements();
            TORCH_DCHECK_EQ(elems.size(), p_node->num_outputs());
            for (const auto i : c10::irange(elems.size())) {
              p_node->Output(i) = elems[i];
            }
          };
        case BlockRunPlan::kRunOnlyTrueBlock:
          return [](ProcessedNode* p_node) {
            auto condition = p_node->Input(0).toBool();
            auto* metadata = p_node->metadata();
            DCHECK(metadata);
            auto& block_runners = metadata->block_runners();
            TORCH_DCHECK_EQ(block_runners.size(), 2);
            if (condition) {
              auto output = block_runners.front()({});
              DCHECK(output.isNone());
            }
          };
        case BlockRunPlan::kRunOnlyFalseBlock:
          return [](ProcessedNode* p_node) {
            auto condition = p_node->Input(0).toBool();
            auto* metadata = p_node->metadata();
            DCHECK(metadata);
            auto& block_runners = metadata->block_runners();
            TORCH_DCHECK_EQ(block_runners.size(), 2);
            if (!condition) {
              auto output = block_runners.back()({});
              DCHECK(output.isNone());
            }
          };
        case BlockRunPlan::kRunNeitherBlock:
          return [](ProcessedNode*) {};
      }
      return [](ProcessedNode*) {};
    });
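// Illustrative note (not from the original source): the specialized plans
// above let an if-statement whose blocks produce nothing (both block outputs
// are None) skip launching an empty block runner. For example,
//   if cond:
//       foo(x)        # true block non-empty, false block empty
// selects kRunOnlyTrueBlock, while any if-statement that actually returns
// values falls through to kRunBothBlocks, which dispatches on the condition
// at run time via block_runners[!condition].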
namespace {

std::vector<IValue> collectLoopSubBlockInputs(const ProcessedNode& p_node) {
  const auto num_inputs = p_node.num_inputs();
  TORCH_DCHECK_GE(num_inputs, 2);
  // The first two inputs to the loop node are the max trip count
  // and initial condition. We don't collect them here, since those
  // are not inputs for the sub-block.
  const auto num_args = num_inputs - 2;

  std::vector<IValue> result;
  result.reserve(num_args + 1);
  // First argument to the loop sub-block is always the loop counter, initially
  // zero.
  result.emplace_back(0);

  for (const auto i : c10::irange(num_args)) {
    result.push_back(p_node.Input(2 + i));
  }

  return result;
}

} // namespace

namespace {
/*
  ForkedSubgraphSRLauncher is responsible for the execution of a forked
  subgraph on a new instance of static runtime. Once the execution is
  completed, the future is marked as complete to indicate to aten::wait()
  that it may proceed.
*/
class TORCH_API ForkedSubgraphSRLauncher {
 public:
  ForkedSubgraphSRLauncher(
      std::shared_ptr<StaticModule> smodule,
      std::vector<IValue> args,
      c10::intrusive_ptr<c10::ivalue::Future> future,
      TaskLauncher launcher)
      : smodule_(std::move(smodule)),
        args_(std::move(args)),
        future_(std::move(future)),
        launcher_(std::move(launcher)) {}

  void operator()() {
    try {
      StaticRuntime runtime(*smodule_);
      auto future_subgraph = runtime.runAsync(args_, {}, launcher_);
      future_subgraph->waitAndThrow();
      future_->markCompleted(future_subgraph->value());
    } catch (const std::exception& e) {
      future_->setErrorIfNeeded(
          std::make_exception_ptr(c10::ivalue::Future::FutureError(e.what())));
    }
  }

 private:
  std::shared_ptr<StaticModule> smodule_;
  std::vector<IValue> args_;
  c10::intrusive_ptr<c10::ivalue::Future> future_;
  torch::jit::TaskLauncher launcher_;
};

/*
  Helper function to create a future on the return type of the graph
  outputs. This function is utilized by the prim::fork and aten::wait
  operations for async execution of subgraphs.
*/
c10::intrusive_ptr<c10::ivalue::Future> createFutureTypeFromGraphOutput(
    const std::shared_ptr<Graph>& graph) {
  TypePtr return_type_;
  if (graph->outputs().size() == 1) {
    return_type_ = graph->outputs().at(0)->type();
  } else {
    return_type_ = TupleType::create(
        fmap(graph->outputs(), [](const Value* v) { return v->type(); }));
  }
  c10::intrusive_ptr<c10::ivalue::Future> future =
      c10::make_intrusive<c10::ivalue::Future>(return_type_);
  return future;
}
} // namespace

/*
  prim::fork forks the execution of a subgraph. It returns a future on which
  the corresponding aten::wait op waits until the future is marked complete.
  The current implementation creates an instance of StaticModule and uses it
  to create StaticRuntime instances on the fly during runtime to handle the
  execution of the forked subgraph. Async execution is handled by the
  aten::ParallelThreadPoolNative threadpool.
*/
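// Illustrative TorchScript usage (an assumption for exposition; not part of
// this file):
//
//   @torch.jit.script
//   def main(x):
//       fut = torch.jit.fork(forked_fn, x)   # lowered to prim::fork
//       return torch.jit.wait(fut)           # lowered to aten::wait
//
// Here prim::fork carries forked_fn's graph as its Subgraph attribute; the
// functor below inlines it, wraps it in a StaticModule, and schedules a
// ForkedSubgraphSRLauncher on the task launcher.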
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::fork,
    prim_Fork,
    [](Node* node) -> SROperator {
      if (!sr_schema_check_kind(node, prim::fork)) {
        return nullptr;
      }
      auto forkedGraph = node->g(attr::Subgraph);
      Inline(*forkedGraph);
      auto sr_metadata = node->ival(getStaticRuntimeMetadataSymbol())
                             .toCustomClass<StaticRuntimeMetadata>();
      auto smodule =
          std::make_shared<StaticModule>(forkedGraph, sr_metadata->get_opts());

      return [forkedGraph = std::move(forkedGraph),
              smodule = std::move(smodule)](ProcessedNode* p_node) {
        std::vector<IValue> args;
        args.reserve(p_node->num_inputs());
        for (const auto i : c10::irange(p_node->num_inputs())) {
          args.push_back(p_node->Input(i));
        }

        c10::intrusive_ptr<c10::ivalue::Future> future =
            createFutureTypeFromGraphOutput(forkedGraph);
        p_node->Output(0) = future;

        auto* metadata = p_node->metadata();
        DCHECK(metadata);
        auto* launcher = metadata->launcher();
        DCHECK(launcher);
        ForkedSubgraphSRLauncher runtime_launcher(
            smodule, args, future, *launcher);
        (*launcher)(std::move(runtime_launcher));
      };
    });

/*
  aten::wait waits on the future (created by the corresponding fork) to be
  executed. Once the execution is complete, the future is marked completed
  and execution of wait continues.
*/
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::wait,
    aten_Wait,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "aten::wait(Future(t) self) -> t")) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        TORCH_INTERNAL_ASSERT(p_node->Input(0).isFuture());
        auto future = p_node->Input(0).toFuture();

        // blocking call: waiting for the future to be completed
        future->waitAndThrow();

        TORCH_INTERNAL_ASSERT(future->completed());
        TORCH_INTERNAL_ASSERT(!future->hasError());
        TORCH_INTERNAL_ASSERT(future->hasValue());

        if (!future->value().isTuple()) {
          p_node->Output(0) = future->value();
          return;
        }
        auto& elems = future->value().toTupleRef().elements();
        TORCH_DCHECK_EQ(elems.size(), p_node->num_outputs());
        for (const auto i : c10::irange(elems.size())) {
          p_node->Output(i) = elems[i];
        }
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::Loop,
    prim_Loop,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::Loop)) {
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        const auto max_trip_count = p_node->Input(0).toInt();
        auto condition = p_node->Input(1).toBool();

        auto* metadata = p_node->metadata();
        DCHECK(metadata);
        auto& block_runners = metadata->block_runners();
        TORCH_DCHECK_EQ(block_runners.size(), 1);
        auto& runner = block_runners[0];

        auto args = collectLoopSubBlockInputs(*p_node);
        int64_t loop_count = 0;

        while (condition && loop_count < max_trip_count) {
          auto output = runner(args);

          if (output.isTuple()) {
            auto& elems = output.toTupleRef().elements();
            DCHECK(elems.size() == args.size());
            for (const auto i : c10::irange(1, args.size())) {
              args[i] = elems[i];
            }
            condition = elems[0].toBool();
          } else {
            condition = output.toBool();
          }
          args[0] = ++loop_count;
        }

        const auto num_outputs = p_node->num_outputs();
        TORCH_DCHECK_EQ(args.size(), num_outputs + 1);
        for (const auto i : c10::irange(num_outputs)) {
          p_node->Output(i) = std::move(args[i + 1]);
        }
      };
    });
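// Illustrative note (not from the original source): the loop sub-block's
// calling convention is args = [loop_counter, carried_0, ..., carried_n],
// and each iteration returns (continue_condition, carried_0', ...,
// carried_n'). So for a TorchScript loop such as
//   for i in range(3):
//       acc = acc + i
// the runner sees args[0] = i and args[1] = acc, and once the trip count or
// condition stops the loop, the final carried values args[1..] become the
// node's outputs.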
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::CreateObject,
    prim_CreateObject,
    [](Node* node) -> SROperator {
      if (!sr_schema_check_kind(node, prim::CreateObject)) {
        return nullptr;
      }
      auto class_type = node->output()->type()->expect<ClassType>();
      return [class_type = std::move(class_type)](ProcessedNode* pnode) {
        pnode->Output(0) = c10::ivalue::Object::create(
            c10::StrongTypePtr(class_type->compilation_unit(), class_type),
            class_type->numAttributes());
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::TupleIndex,
    prim_TupleIndex,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::TupleIndex)) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& elems = pnode->Input(0).toTupleRef().elements();
        using c10::ssize;
        const auto num_elems = ssize(elems);
        const auto idx = pnode->Input(1).toInt();
        const auto norm_idx = normalizeIndex(idx, num_elems);
        if (norm_idx < 0 || norm_idx >= num_elems) {
          // Use a std:: exception instead of c10::Error to be consistent
          // with JIT
          throw std::out_of_range("Tuple index out of range");
        }
        pnode->Output(0) = elems[norm_idx];
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::RaiseException,
    prim_RaiseException,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::RaiseException)) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& message = pnode->Input(0).toStringRef();
        throw std::runtime_error(message);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::Uninitialized,
    prim_Uninitialized,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::Uninitialized)) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        pnode->Output(0) = IValue::uninitialized();
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::format,
    aten_format,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "aten::format(str self, ...) -> str")) {
        return nullptr;
      }
      TORCH_CHECK(!n->inputs().empty());
      return [](ProcessedNode* pnode) {
        const auto num_inputs = pnode->num_inputs();
        auto stack = boxInputs(*pnode);
        format(stack, num_inputs);
        TORCH_DCHECK_EQ(stack.size(), 1);
        pnode->Output(0) = std::move(stack[0]);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::device,
    prim_device,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "prim::device(Tensor a) -> Device")) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& input = pnode->Input(0).toTensor();
        pnode->Output(0) = input.device();
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::dtype,
    prim_dtype,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::dtype)) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& input = pnode->Input(0).toTensor();
        pnode->Output(0) = static_cast<int64_t>(input.scalar_type());
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::dim,
    aten_dim,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "aten::dim(Tensor self) -> int")) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& input = pnode->Input(0).toTensor();
        pnode->Output(0) = input.dim();
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::__not__,
    aten_not,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "aten::__not__(bool self) -> bool")) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        auto input = pnode->Input(0).toBool();
        pnode->Output(0) = !input;
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::Bool,
    aten_Bool,
    [](Node* n) -> SROperator {
      if (n->matches(torch::schema("aten::Bool.Tensor(Tensor a) -> bool"))) {
        return [](ProcessedNode* pnode) {
          const auto& input = pnode->Input(0).toTensor();
          pnode->Output(0) = at::native::is_nonzero(input);
        };
      }
      if (n->matches(torch::schema("aten::Bool.int(int a) -> bool"))) {
        return [](ProcessedNode* pnode) {
          const auto input = pnode->Input(0).toInt();
          pnode->Output(0) = static_cast<bool>(input);
        };
      }
      if (n->matches(torch::schema("aten::Bool.float(float a) -> bool"))) {
        return [](ProcessedNode* pnode) {
          const auto input = pnode->Input(0).toDouble();
          pnode->Output(0) = static_cast<bool>(input);
        };
      }
      LogAndDumpSchema(n);
      return nullptr;
    });
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::is_cuda,
    prim_is_cuda,
    [](Node* n) -> SROperator {
      if (!sr_schema_check(n, "prim::is_cuda(Tensor a) -> bool")) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& input = pnode->Input(0).toTensor();
        pnode->Output(0) = input.is_cuda();
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::tolist,
    prim_tolist,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::tolist)) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& input = pnode->Input(0).toTensor();
        const auto dim = pnode->Input(1).toInt();
        const auto elem_type = pnode->Input(2).toInt();
        std::vector<IValue> stack{input, dim, elem_type};
        toList(stack);
        TORCH_DCHECK_EQ(stack.size(), 1);
        pnode->Output(0) = std::move(stack[0]);
      };
    });

// See [Borrowed IValue Outputs]
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    prim::IfThenElse,
    prim_IfThenElse,
    [](Node* n) -> SROperator {
      if (!sr_schema_check_kind(n, prim::IfThenElse)) {
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto condition = pnode->Input(0).toBool();
        pnode->Output(0) = condition ? createBorrowedIValue(pnode->Input(1))
                                     : createBorrowedIValue(pnode->Input(2));
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::len,
    aten_len,
    [](Node* n) -> SROperator {
      if (n->matches(torch::schema("aten::len.t(t[] a) -> int")) ||
          n->matches(torch::schema("aten::len.any(Any[] a) -> int"))) {
        return [](ProcessedNode* pnode) {
          const auto list = pnode->Input(0).toListRef();
          const int64_t size = list.size();
          pnode->Output(0) = size;
        };
      }
      if (n->matches(torch::schema("aten::len.Tensor(Tensor t) -> int"))) {
        return [](ProcessedNode* pnode) {
          const auto& t = pnode->Input(0).toTensor();
          TORCH_CHECK(t.dim() > 0);
          pnode->Output(0) = t.sizes()[0];
        };
      }
      if (n->matches(torch::schema("aten::len.str(str s) -> int"))) {
        return [](ProcessedNode* pnode) {
          const auto& string = pnode->Input(0).toStringRef();
          pnode->Output(0) = static_cast<int64_t>(string.size());
        };
      }
      if (n->matches(
              torch::schema("aten::len.Dict_str(Dict(str, t) self) -> int")) ||
          n->matches(
              torch::schema("aten::len.Dict_int(Dict(int, t) self) -> int")) ||
          n->matches(torch::schema(
              "aten::len.Dict_bool(Dict(bool, t) self) -> int")) ||
          n->matches(torch::schema(
              "aten::len.Dict_float(Dict(float, t) self) -> int")) ||
          n->matches(torch::schema(
              "aten::len.Dict_complex(Dict(complex, t) self) -> int")) ||
          n->matches(torch::schema(
              "aten::len.Dict_Tensor(Dict(Tensor, t) self) -> int"))) {
        return [](ProcessedNode* pnode) {
          const auto& dict = pnode->Input(0).toGenericDict();
          pnode->Output(0) = static_cast<int64_t>(dict.size());
        };
      }
      LogAndDumpSchema(n);
      return nullptr;
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::IntImplicit,
    aten_IntImplicit,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema("aten::IntImplicit(Tensor a) -> int"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& tensor = pnode->Input(0).toTensor();
        // JIT does a check for requires_grad, but we skip it here since SR is
        // inference only
        if (!tensor.sizes().empty()) {
          throw std::runtime_error(
              "Cannot convert a tensor of dimension > 0 to scalar");
        }
        if (!isIntegralType(tensor.scalar_type(), /*includeBool=*/false)) {
          std::stringstream ss;
          ss << "Cannot input a tensor of type " << tensor.scalar_type()
             << " as an integral argument";
          throw std::runtime_error(ss.str());
        }
        pnode->Output(0) = at::native::item(tensor).toInt();
      };
    });
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::select,
    aten_select,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::select(Tensor(a) self, int dim, int index) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& self = pnode->Input(0).toTensor();
        const auto dim = pnode->Input(1).toInt();
        const auto index = pnode->Input(2).toInt();
        pnode->Output(0) = at::native::select(self, dim, index);
      };
    });

REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::reshape_as,
    aten_reshape_as,
    [](Node* n) -> SROperator {
      if (!n->matches(torch::schema(
              "aten::reshape_as(Tensor(a) self, Tensor other) -> Tensor(a)"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* pnode) {
        const auto& self = pnode->Input(0).toTensor();
        const auto& other = pnode->Input(1).toTensor();
        pnode->Output(0) = at::native::reshape(self, other.sizes());
      };
    });

} // namespace torch::jit