[static runtime] Add native ops: aten::index_put, aten::item, aten::tensor_split (#79065)

Summary: This adds the pytorch operators that are currently missing in non-ads models from c2->pt mitigation: aten::index_put, aten::item, aten::tensor_split

Test Plan: buck run mode/opt caffe2/benchmarks/static_runtime:static_runtime_cpptest

Differential Revision: D36984961

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79065
Approved by: https://github.com/davidberard98
This commit is contained in:
Hui Guo
2022-06-15 19:15:31 +00:00
committed by PyTorch MergeBot
parent 0bc1b9e039
commit 8d7fcfa8f1
2 changed files with 108 additions and 0 deletions

View File

@ -2482,6 +2482,56 @@ TEST(StaticRuntime, LinalgNorm_StringOrd) {
testStaticRuntime(linalg_norm_ord_str, args0, args1);
}
TEST(StaticRuntime, Index_Put) {
const auto index_put_str = R"JIT(
def forward(self, a: Tensor, indices: Tuple[Optional[Tensor]], values: Tensor, accumulate: bool):
return torch.index_put(a, indices, values, accumulate).clone()
)JIT";
auto a = at::randn({2});
auto indicies_a = std::make_tuple(torch::tensor({0}, at::kLong));
auto values_a = at::randn({1});
std::vector<IValue> args0{a, indicies_a, values_a, false};
testStaticRuntime(index_put_str, args0);
}
TEST(StaticRuntime, Item) {
const auto item_str = R"JIT(
def forward(self, a: Tensor):
return torch.item(a)
)JIT";
auto a = at::randn({1});
std::vector<IValue> args0{a};
testStaticRuntime(item_str, args0);
}
TEST(StaticRuntime, Tensor_Split) {
const auto tensor_split_str1 = R"JIT(
def forward(self, a: Tensor, sections: int, dim: int):
return torch.tensor_split(a, sections, dim)
)JIT";
std::vector<IValue> args1{at::randn({8}), 3, 0};
const auto tensor_split_str2 = R"JIT(
def forward(self, a: Tensor, sections: Tensor, dim: int):
return torch.tensor_split(a, sections, dim)
)JIT";
std::vector<IValue> args2{at::randn({8}), torch::tensor(3), 0};
const auto tensor_split_str3 = R"JIT(
def forward(self, a: Tensor, indicies: List[int], dim: int):
return torch.tensor_split(a, indicies, dim)
)JIT";
std::vector<IValue> args3{at::randn({8}), c10::List<int64_t>({1, 6}), 0};
testStaticRuntime(tensor_split_str1, args1);
testStaticRuntime(tensor_split_str2, args2);
testStaticRuntime(tensor_split_str3, args3);
}
TEST(StaticRuntime, Cat) {
const std::string cat_script = R"IR(
graph(%a: Tensor, %b: Tensor, %dim: int):

View File

@ -200,6 +200,30 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
};
});
REGISTER_NATIVE_OPERATOR_FUNCTOR(
aten::index_put,
aten_index_put,
[](Node* n) -> SROperator {
return [](ProcessedNode* p_node) {
const auto& self = p_node->Input(0).toTensor();
const auto& indices = p_node->Input(1).toOptionalTensorList();
const auto& values = p_node->Input(2).toTensor();
const auto accumulate = p_node->Input(3).toBool();
p_node->Output(0) =
at::native::index_put(self, indices, values, accumulate);
};
});
REGISTER_NATIVE_OPERATOR_FUNCTOR(
aten::item,
aten_item,
[](Node* n) -> SROperator {
return [](ProcessedNode* p_node) {
const auto& self = p_node->Input(0).toTensor();
p_node->Output(0) = at::native::item(self);
};
});
REGISTER_NATIVE_OPERATOR_FUNCTOR(
prim::GetAttr,
prim_GetAttr,
@ -694,6 +718,40 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
return nullptr;
});
REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::tensor_split, aten_tensor_split, [](Node* n) -> SROperator {
if (n->matches(torch::schema(
"tensor_split.indices(Tensor(a -> *) self, int[] indices, int dim=0) -> Tensor(a)[]"))) {
return [](ProcessedNode* pnode) {
const auto& a = pnode->Input(0).toTensor();
const auto& b = pnode->Input(1).toIntVector();
const auto c = pnode->Input(2).toInt();
pnode->Output(0) = at::native::tensor_split(a, b, c);
};
}
if (n->matches(torch::schema(
"tensor_split.sections(Tensor(a -> *) self, int sections, int dim=0) -> Tensor(a)[]"))) {
return [](ProcessedNode* pnode) {
const auto& a = pnode->Input(0).toTensor();
const auto b = pnode->Input(1).toInt();
const auto c = pnode->Input(2).toInt();
pnode->Output(0) = at::native::tensor_split(a, b, c);
};
}
if (n->matches(torch::schema(
"tensor_split.tensor_indices_or_sections(Tensor(a -> *) self, Tensor tensor_indices_or_sections, int dim=0) -> Tensor(a)[]"))) {
return [](ProcessedNode* pnode) {
const auto& a = pnode->Input(0).toTensor();
const auto& b = pnode->Input(1).toTensor();
const auto c = pnode->Input(2).toInt();
pnode->Output(0) = at::native::tensor_split(a, b, c);
};
}
LogAndDumpSchema(n);
return nullptr;
});
REGISTER_NATIVE_OPERATOR_FUNCTOR(
aten::Int,
aten_Int,