[static runtime] Add native ops: aten::index_put, aten::item, aten::tensor_split (#79065)

Summary: This adds the pytorch operators that are currently missing in non-ads models from c2->pt mitigation: aten::index_put, aten::item, aten::tensor_split

Test Plan: buck run mode/opt caffe2/benchmarks/static_runtime:static_runtime_cpptest

Differential Revision: D36984961

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79065
Approved by: https://github.com/davidberard98
This commit is contained in:
Hui Guo
2022-06-15 19:15:31 +00:00
committed by PyTorch MergeBot
parent 0bc1b9e039
commit 8d7fcfa8f1
2 changed files with 108 additions and 0 deletions

View File

@ -2482,6 +2482,56 @@ TEST(StaticRuntime, LinalgNorm_StringOrd) {
testStaticRuntime(linalg_norm_ord_str, args0, args1);
}
TEST(StaticRuntime, Index_Put) {
const auto index_put_str = R"JIT(
def forward(self, a: Tensor, indices: Tuple[Optional[Tensor]], values: Tensor, accumulate: bool):
return torch.index_put(a, indices, values, accumulate).clone()
)JIT";
auto a = at::randn({2});
auto indicies_a = std::make_tuple(torch::tensor({0}, at::kLong));
auto values_a = at::randn({1});
std::vector<IValue> args0{a, indicies_a, values_a, false};
testStaticRuntime(index_put_str, args0);
}
TEST(StaticRuntime, Item) {
const auto item_str = R"JIT(
def forward(self, a: Tensor):
return torch.item(a)
)JIT";
auto a = at::randn({1});
std::vector<IValue> args0{a};
testStaticRuntime(item_str, args0);
}
TEST(StaticRuntime, Tensor_Split) {
const auto tensor_split_str1 = R"JIT(
def forward(self, a: Tensor, sections: int, dim: int):
return torch.tensor_split(a, sections, dim)
)JIT";
std::vector<IValue> args1{at::randn({8}), 3, 0};
const auto tensor_split_str2 = R"JIT(
def forward(self, a: Tensor, sections: Tensor, dim: int):
return torch.tensor_split(a, sections, dim)
)JIT";
std::vector<IValue> args2{at::randn({8}), torch::tensor(3), 0};
const auto tensor_split_str3 = R"JIT(
def forward(self, a: Tensor, indicies: List[int], dim: int):
return torch.tensor_split(a, indicies, dim)
)JIT";
std::vector<IValue> args3{at::randn({8}), c10::List<int64_t>({1, 6}), 0};
testStaticRuntime(tensor_split_str1, args1);
testStaticRuntime(tensor_split_str2, args2);
testStaticRuntime(tensor_split_str3, args3);
}
TEST(StaticRuntime, Cat) {
const std::string cat_script = R"IR(
graph(%a: Tensor, %b: Tensor, %dim: int):