mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[static runtime] Add native ops: aten::index_put, aten::item, aten::tensor_split (#79065)
Summary: This adds the pytorch operators that are currently missing in non-ads models from c2->pt mitigation: aten::index_put, aten::item, aten::tensor_split Test Plan: buck run mode/opt caffe2/benchmarks/static_runtime:static_runtime_cpptest Differential Revision: D36984961 Pull Request resolved: https://github.com/pytorch/pytorch/pull/79065 Approved by: https://github.com/davidberard98
This commit is contained in:
committed by
PyTorch MergeBot
parent
0bc1b9e039
commit
8d7fcfa8f1
@ -2482,6 +2482,56 @@ TEST(StaticRuntime, LinalgNorm_StringOrd) {
|
||||
testStaticRuntime(linalg_norm_ord_str, args0, args1);
|
||||
}
|
||||
|
||||
TEST(StaticRuntime, Index_Put) {
|
||||
const auto index_put_str = R"JIT(
|
||||
def forward(self, a: Tensor, indices: Tuple[Optional[Tensor]], values: Tensor, accumulate: bool):
|
||||
return torch.index_put(a, indices, values, accumulate).clone()
|
||||
)JIT";
|
||||
|
||||
auto a = at::randn({2});
|
||||
auto indicies_a = std::make_tuple(torch::tensor({0}, at::kLong));
|
||||
auto values_a = at::randn({1});
|
||||
|
||||
std::vector<IValue> args0{a, indicies_a, values_a, false};
|
||||
testStaticRuntime(index_put_str, args0);
|
||||
}
|
||||
|
||||
TEST(StaticRuntime, Item) {
|
||||
const auto item_str = R"JIT(
|
||||
def forward(self, a: Tensor):
|
||||
return torch.item(a)
|
||||
)JIT";
|
||||
|
||||
auto a = at::randn({1});
|
||||
|
||||
std::vector<IValue> args0{a};
|
||||
testStaticRuntime(item_str, args0);
|
||||
}
|
||||
|
||||
TEST(StaticRuntime, Tensor_Split) {
|
||||
const auto tensor_split_str1 = R"JIT(
|
||||
def forward(self, a: Tensor, sections: int, dim: int):
|
||||
return torch.tensor_split(a, sections, dim)
|
||||
)JIT";
|
||||
std::vector<IValue> args1{at::randn({8}), 3, 0};
|
||||
|
||||
const auto tensor_split_str2 = R"JIT(
|
||||
def forward(self, a: Tensor, sections: Tensor, dim: int):
|
||||
return torch.tensor_split(a, sections, dim)
|
||||
)JIT";
|
||||
std::vector<IValue> args2{at::randn({8}), torch::tensor(3), 0};
|
||||
|
||||
const auto tensor_split_str3 = R"JIT(
|
||||
def forward(self, a: Tensor, indicies: List[int], dim: int):
|
||||
return torch.tensor_split(a, indicies, dim)
|
||||
)JIT";
|
||||
std::vector<IValue> args3{at::randn({8}), c10::List<int64_t>({1, 6}), 0};
|
||||
|
||||
testStaticRuntime(tensor_split_str1, args1);
|
||||
testStaticRuntime(tensor_split_str2, args2);
|
||||
testStaticRuntime(tensor_split_str3, args3);
|
||||
}
|
||||
|
||||
TEST(StaticRuntime, Cat) {
|
||||
const std::string cat_script = R"IR(
|
||||
graph(%a: Tensor, %b: Tensor, %dim: int):
|
||||
|
@ -200,6 +200,30 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
|
||||
};
|
||||
});
|
||||
|
||||
REGISTER_NATIVE_OPERATOR_FUNCTOR(
|
||||
aten::index_put,
|
||||
aten_index_put,
|
||||
[](Node* n) -> SROperator {
|
||||
return [](ProcessedNode* p_node) {
|
||||
const auto& self = p_node->Input(0).toTensor();
|
||||
const auto& indices = p_node->Input(1).toOptionalTensorList();
|
||||
const auto& values = p_node->Input(2).toTensor();
|
||||
const auto accumulate = p_node->Input(3).toBool();
|
||||
p_node->Output(0) =
|
||||
at::native::index_put(self, indices, values, accumulate);
|
||||
};
|
||||
});
|
||||
|
||||
REGISTER_NATIVE_OPERATOR_FUNCTOR(
|
||||
aten::item,
|
||||
aten_item,
|
||||
[](Node* n) -> SROperator {
|
||||
return [](ProcessedNode* p_node) {
|
||||
const auto& self = p_node->Input(0).toTensor();
|
||||
p_node->Output(0) = at::native::item(self);
|
||||
};
|
||||
});
|
||||
|
||||
REGISTER_NATIVE_OPERATOR_FUNCTOR(
|
||||
prim::GetAttr,
|
||||
prim_GetAttr,
|
||||
@ -694,6 +718,40 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::tensor_split, aten_tensor_split, [](Node* n) -> SROperator {
|
||||
if (n->matches(torch::schema(
|
||||
"tensor_split.indices(Tensor(a -> *) self, int[] indices, int dim=0) -> Tensor(a)[]"))) {
|
||||
return [](ProcessedNode* pnode) {
|
||||
const auto& a = pnode->Input(0).toTensor();
|
||||
const auto& b = pnode->Input(1).toIntVector();
|
||||
const auto c = pnode->Input(2).toInt();
|
||||
pnode->Output(0) = at::native::tensor_split(a, b, c);
|
||||
};
|
||||
}
|
||||
|
||||
if (n->matches(torch::schema(
|
||||
"tensor_split.sections(Tensor(a -> *) self, int sections, int dim=0) -> Tensor(a)[]"))) {
|
||||
return [](ProcessedNode* pnode) {
|
||||
const auto& a = pnode->Input(0).toTensor();
|
||||
const auto b = pnode->Input(1).toInt();
|
||||
const auto c = pnode->Input(2).toInt();
|
||||
pnode->Output(0) = at::native::tensor_split(a, b, c);
|
||||
};
|
||||
}
|
||||
|
||||
if (n->matches(torch::schema(
|
||||
"tensor_split.tensor_indices_or_sections(Tensor(a -> *) self, Tensor tensor_indices_or_sections, int dim=0) -> Tensor(a)[]"))) {
|
||||
return [](ProcessedNode* pnode) {
|
||||
const auto& a = pnode->Input(0).toTensor();
|
||||
const auto& b = pnode->Input(1).toTensor();
|
||||
const auto c = pnode->Input(2).toInt();
|
||||
pnode->Output(0) = at::native::tensor_split(a, b, c);
|
||||
};
|
||||
}
|
||||
LogAndDumpSchema(n);
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
REGISTER_NATIVE_OPERATOR_FUNCTOR(
|
||||
aten::Int,
|
||||
aten_Int,
|
||||
|
Reference in New Issue
Block a user