Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 05:34:18 +08:00
[static runtime] Add native ops: aten::index_put, aten::item, aten::tensor_split (#79065)
Summary: This adds the PyTorch operators that are currently missing in non-ads models from the c2->pt migration: aten::index_put, aten::item, aten::tensor_split

Test Plan: buck run mode/opt caffe2/benchmarks/static_runtime:static_runtime_cpptest

Differential Revision: D36984961

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79065
Approved by: https://github.com/davidberard98
committed by PyTorch MergeBot · parent 0bc1b9e039 · commit 8d7fcfa8f1
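For orientation, the three ops being added have direct ATen-level counterparts. The sketch below is illustrative only (it is not part of the diff) and assumes libtorch's public at::index_put, at::tensor_split, and at::Tensor::item<T>() APIs:

#include <torch/torch.h>
#include <iostream>

int main() {
  // aten::index_put: write `values` into `a` at the positions selected by
  // `indices` (a list of optional index tensors, one per indexed dimension).
  auto a = at::randn({2});
  c10::List<c10::optional<at::Tensor>> indices;
  indices.push_back(torch::tensor({0}, at::kLong));
  auto values = at::randn({1});
  auto out = at::index_put(a, indices, values, /*accumulate=*/false);

  // aten::item: extract the single element of a one-element tensor.
  float x = at::randn({1}).item<float>();

  // aten::tensor_split: split a length-8 tensor into 3 chunks along dim 0
  // (chunk sizes 3, 3, 2).
  std::vector<at::Tensor> chunks = at::tensor_split(at::randn({8}), 3, 0);

  std::cout << out << "\n" << x << "\n" << chunks.size() << std::endl;
}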
@@ -2482,6 +2482,56 @@ TEST(StaticRuntime, LinalgNorm_StringOrd) {
   testStaticRuntime(linalg_norm_ord_str, args0, args1);
 }
 
+TEST(StaticRuntime, Index_Put) {
+  const auto index_put_str = R"JIT(
+    def forward(self, a: Tensor, indices: Tuple[Optional[Tensor]], values: Tensor, accumulate: bool):
+        return torch.index_put(a, indices, values, accumulate).clone()
+  )JIT";
+
+  auto a = at::randn({2});
+  auto indicies_a = std::make_tuple(torch::tensor({0}, at::kLong));
+  auto values_a = at::randn({1});
+
+  std::vector<IValue> args0{a, indicies_a, values_a, false};
+  testStaticRuntime(index_put_str, args0);
+}
+
+TEST(StaticRuntime, Item) {
+  const auto item_str = R"JIT(
+    def forward(self, a: Tensor):
+        return torch.item(a)
+  )JIT";
+
+  auto a = at::randn({1});
+
+  std::vector<IValue> args0{a};
+  testStaticRuntime(item_str, args0);
+}
+
+TEST(StaticRuntime, Tensor_Split) {
+  const auto tensor_split_str1 = R"JIT(
+    def forward(self, a: Tensor, sections: int, dim: int):
+        return torch.tensor_split(a, sections, dim)
+  )JIT";
+  std::vector<IValue> args1{at::randn({8}), 3, 0};
+
+  const auto tensor_split_str2 = R"JIT(
+    def forward(self, a: Tensor, sections: Tensor, dim: int):
+        return torch.tensor_split(a, sections, dim)
+  )JIT";
+  std::vector<IValue> args2{at::randn({8}), torch::tensor(3), 0};
+
+  const auto tensor_split_str3 = R"JIT(
+    def forward(self, a: Tensor, indicies: List[int], dim: int):
+        return torch.tensor_split(a, indicies, dim)
+  )JIT";
+  std::vector<IValue> args3{at::randn({8}), c10::List<int64_t>({1, 6}), 0};
+
+  testStaticRuntime(tensor_split_str1, args1);
+  testStaticRuntime(tensor_split_str2, args2);
+  testStaticRuntime(tensor_split_str3, args3);
+}
+
 TEST(StaticRuntime, Cat) {
   const std::string cat_script = R"IR(
     graph(%a: Tensor, %b: Tensor, %dim: int):
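Each testStaticRuntime call above compiles the TorchScript snippet, runs it through Static Runtime, and checks the results against a reference execution. A minimal standalone sketch of that round trip, assuming the torch::jit::StaticModule API from torch/csrc/jit/runtime/static/impl.h (constructor and call signatures vary slightly across PyTorch versions):

#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/torch.h>
#include <iostream>

int main() {
  // Define the same kind of module the Item test above exercises.
  torch::jit::Module mod("m");
  mod.define(R"JIT(
    def forward(self, a: Tensor):
        return torch.item(a)
  )JIT");

  // Wrap it in Static Runtime and invoke it with positional args.
  torch::jit::StaticModule smod(mod);
  std::vector<c10::IValue> args{torch::randn({1})};
  c10::IValue out = smod(args, /*kwargs=*/{});
  std::cout << out << std::endl;
}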
@@ -200,6 +200,30 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
   };
 });
 
+REGISTER_NATIVE_OPERATOR_FUNCTOR(
+    aten::index_put,
+    aten_index_put,
+    [](Node* n) -> SROperator {
+      return [](ProcessedNode* p_node) {
+        const auto& self = p_node->Input(0).toTensor();
+        const auto& indices = p_node->Input(1).toOptionalTensorList();
+        const auto& values = p_node->Input(2).toTensor();
+        const auto accumulate = p_node->Input(3).toBool();
+        p_node->Output(0) =
+            at::native::index_put(self, indices, values, accumulate);
+      };
+    });
+
+REGISTER_NATIVE_OPERATOR_FUNCTOR(
+    aten::item,
+    aten_item,
+    [](Node* n) -> SROperator {
+      return [](ProcessedNode* p_node) {
+        const auto& self = p_node->Input(0).toTensor();
+        p_node->Output(0) = at::native::item(self);
+      };
+    });
+
 REGISTER_NATIVE_OPERATOR_FUNCTOR(
     prim::GetAttr,
     prim_GetAttr,
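One detail worth noting in the aten::index_put functor: the indices argument arrives as a list of optional tensors, hence toOptionalTensorList(). A small sketch (illustrative only, not from the diff) of how such a value round-trips through an IValue:

#include <torch/torch.h>
#include <iostream>

int main() {
  // Build the kind of value the functor reads back with toOptionalTensorList():
  // one optional index tensor per indexed dimension; c10::nullopt means
  // "take everything along this dimension".
  c10::List<c10::optional<at::Tensor>> indices;
  indices.push_back(torch::tensor({0}, at::kLong));
  indices.push_back(c10::nullopt);

  c10::IValue iv(indices);
  auto round_trip = iv.toOptionalTensorList();
  std::cout << round_trip.size() << std::endl;  // prints 2
}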
@@ -694,6 +718,40 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
   return nullptr;
 });
 
+REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::tensor_split, aten_tensor_split, [](Node* n) -> SROperator {
+  if (n->matches(torch::schema(
+          "aten::tensor_split.indices(Tensor(a -> *) self, int[] indices, int dim=0) -> Tensor(a)[]"))) {
+    return [](ProcessedNode* pnode) {
+      const auto& a = pnode->Input(0).toTensor();
+      const auto& b = pnode->Input(1).toIntVector();
+      const auto c = pnode->Input(2).toInt();
+      pnode->Output(0) = at::native::tensor_split(a, b, c);
+    };
+  }
+
+  if (n->matches(torch::schema(
+          "aten::tensor_split.sections(Tensor(a -> *) self, int sections, int dim=0) -> Tensor(a)[]"))) {
+    return [](ProcessedNode* pnode) {
+      const auto& a = pnode->Input(0).toTensor();
+      const auto b = pnode->Input(1).toInt();
+      const auto c = pnode->Input(2).toInt();
+      pnode->Output(0) = at::native::tensor_split(a, b, c);
+    };
+  }
+
+  if (n->matches(torch::schema(
+          "aten::tensor_split.tensor_indices_or_sections(Tensor(a -> *) self, Tensor tensor_indices_or_sections, int dim=0) -> Tensor(a)[]"))) {
+    return [](ProcessedNode* pnode) {
+      const auto& a = pnode->Input(0).toTensor();
+      const auto& b = pnode->Input(1).toTensor();
+      const auto c = pnode->Input(2).toInt();
+      pnode->Output(0) = at::native::tensor_split(a, b, c);
+    };
+  }
+  LogAndDumpSchema(n);
+  return nullptr;
+});
+
 REGISTER_NATIVE_OPERATOR_FUNCTOR(
     aten::Int,
     aten_Int,
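The tensor_split functor dispatches on the node's schema rather than its name alone, because aten::tensor_split has three overloads with different index-argument types. Each schema string parses to a c10::FunctionSchema; a sketch of that parsing step, assuming the torch::schema helper from torch/library.h that the registration code above also relies on:

#include <torch/library.h>
#include <iostream>

int main() {
  c10::FunctionSchema s = torch::schema(
      "aten::tensor_split.sections(Tensor(a -> *) self, int sections, int dim=0) -> Tensor(a)[]");
  // name() keeps the namespace; overload_name() distinguishes the variants.
  std::cout << s.name() << " / " << s.overload_name() << std::endl;
  // expected output: aten::tensor_split / sections
}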