mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[static runtime] Add native ops: aten::index_put, aten::item, aten::tensor_split (#79065)
Summary: This adds the pytorch operators that are currently missing in non-ads models from c2->pt mitigation: aten::index_put, aten::item, aten::tensor_split Test Plan: buck run mode/opt caffe2/benchmarks/static_runtime:static_runtime_cpptest Differential Revision: D36984961 Pull Request resolved: https://github.com/pytorch/pytorch/pull/79065 Approved by: https://github.com/davidberard98
This commit is contained in:
committed by
PyTorch MergeBot
parent
0bc1b9e039
commit
8d7fcfa8f1
@ -2482,6 +2482,56 @@ TEST(StaticRuntime, LinalgNorm_StringOrd) {
|
||||
testStaticRuntime(linalg_norm_ord_str, args0, args1);
|
||||
}
|
||||
|
||||
// Checks that aten::index_put (non-mutating overload, cloned to detach
// from SR-managed memory) runs correctly under the static runtime.
TEST(StaticRuntime, Index_Put) {
  const auto index_put_str = R"JIT(
    def forward(self, a: Tensor, indices: Tuple[Optional[Tensor]], values: Tensor, accumulate: bool):
        return torch.index_put(a, indices, values, accumulate).clone()
  )JIT";

  // One-element update at position 0 of a length-2 tensor, no accumulation.
  auto self_tensor = at::randn({2});
  auto update_values = at::randn({1});
  auto index_tuple = std::make_tuple(torch::tensor({0}, at::kLong));

  std::vector<IValue> args0{self_tensor, index_tuple, update_values, false};
  testStaticRuntime(index_put_str, args0);
}
|
||||
|
||||
// Checks that aten::item (scalar extraction from a 1-element tensor)
// runs correctly under the static runtime.
TEST(StaticRuntime, Item) {
  const auto item_str = R"JIT(
    def forward(self, a: Tensor):
        return torch.item(a)
  )JIT";

  // item() requires exactly one element.
  auto scalar_input = at::randn({1});
  std::vector<IValue> args0{scalar_input};

  testStaticRuntime(item_str, args0);
}
|
||||
|
||||
// Checks all three aten::tensor_split overloads under the static runtime:
// split by int section count, by a 0-dim Tensor section count, and by an
// explicit list of split indices. Each case is exercised as soon as its
// script and arguments are built.
TEST(StaticRuntime, Tensor_Split) {
  // Overload 1: sections given as an int.
  const auto tensor_split_str1 = R"JIT(
    def forward(self, a: Tensor, sections: int, dim: int):
        return torch.tensor_split(a, sections, dim)
  )JIT";
  std::vector<IValue> args1{at::randn({8}), 3, 0};
  testStaticRuntime(tensor_split_str1, args1);

  // Overload 2: sections given as a 0-dim integer Tensor.
  const auto tensor_split_str2 = R"JIT(
    def forward(self, a: Tensor, sections: Tensor, dim: int):
        return torch.tensor_split(a, sections, dim)
  )JIT";
  std::vector<IValue> args2{at::randn({8}), torch::tensor(3), 0};
  testStaticRuntime(tensor_split_str2, args2);

  // Overload 3: explicit split indices as a List[int].
  const auto tensor_split_str3 = R"JIT(
    def forward(self, a: Tensor, indicies: List[int], dim: int):
        return torch.tensor_split(a, indicies, dim)
  )JIT";
  std::vector<IValue> args3{at::randn({8}), c10::List<int64_t>({1, 6}), 0};
  testStaticRuntime(tensor_split_str3, args3);
}
|
||||
|
||||
TEST(StaticRuntime, Cat) {
|
||||
const std::string cat_script = R"IR(
|
||||
graph(%a: Tensor, %b: Tensor, %dim: int):
|
||||
|
Reference in New Issue
Block a user