[static-runtime] out variant for aten::max (#78271)

Summary: Previously this op was auto-generated and only covered the pointwise overload of aten::max. This adds support for the reduction overloads as well: overall and along a dim.
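
For reference, a minimal sketch (not part of the commit) of the three aten::max forms the out variant now covers, written against the public ATen C++ API; the function name below is purely illustrative:

#include <ATen/ATen.h>

void max_overloads_example() {
  at::Tensor a = at::randn({2, 3});
  at::Tensor b = at::randn({2, 3});

  // aten::max.other: pointwise maximum of two tensors
  // (the only overload previously covered by the auto-generated op).
  at::Tensor pointwise = at::max(a, b);

  // aten::max(Tensor self): overall reduction to a single-element tensor.
  at::Tensor overall = at::max(a);

  // aten::max.dim: reduction along a dimension, returning values and indices.
  auto [values, indices] = at::max(a, /*dim=*/1, /*keepdim=*/false);
}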

Test Plan: Added a unit test

Differential Revision: D36656378

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78271
Approved by: https://github.com/mikeiovine
Author: Max Podkorytov
Date: 2022-05-26 23:29:27 +00:00
Committed by: PyTorch MergeBot
Commit: 2679755bdc (parent ac031e1326)
3 changed files with 92 additions and 19 deletions


@@ -94,6 +94,43 @@ TEST(StaticRuntime, UnaryOps) {
  testStaticRuntime(aten_sum_1_true, args, args2);
}

TEST(StaticRuntime, Max) {
  auto src_max_reduce = R"JIT(
    def forward(self, input):
        return torch.max(input).clone()
  )JIT";
  auto src_max_dim = R"JIT(
    def forward(self, input, dim: int):
        values, indices = torch.max(input, dim)
        return values.clone(), indices.clone()
  )JIT";
  auto src_max_dim_keepdim = R"JIT(
    def forward(self, input, dim: int):
        values, indices = torch.max(input, dim, keepdim=True)
        return values.clone(), indices.clone()
  )JIT";
  auto src_max_pointwise = R"JIT(
    def forward(self, input, other):
        return torch.max(input, other).clone()
  )JIT";

  auto input = at::randn({2, 3, 2});
  auto input_other = at::randn({2, 3, 2});
  auto large_input = at::randn({8, 9, 10});
  auto large_input_other = at::randn({8, 9, 10});

  testStaticRuntime(src_max_reduce, {input});
  testStaticRuntime(src_max_dim, {input, 1});
  testStaticRuntime(src_max_dim, {input, 1}, {large_input, 0});
  testStaticRuntime(src_max_dim_keepdim, {input, 0});
  testStaticRuntime(src_max_dim_keepdim, {input, 0}, {large_input, 2});
  testStaticRuntime(src_max_pointwise, {input, input_other});
  testStaticRuntime(src_max_pointwise, {input, input_other}, {large_input, large_input_other});
}

TEST(StaticRuntime, Sigmoid) {
  const auto sigmoid_script = R"JIT(
    def forward(self, inp: Tensor):


@@ -3167,25 +3167,6 @@ REGISTER_OPERATOR_FUNCTOR(
      return nullptr;
    });

REGISTER_OPERATOR_FUNCTOR(aten::max, aten_max, [](Node* n) -> SROperator {
  if (n->matches(torch::schema(
          "aten::max.other(Tensor self, Tensor other) -> Tensor"))) {
    return [](ProcessedNode* p_node) {
      const auto& self = p_node->Input(0).toTensor();
      const auto& other = p_node->Input(1).toTensor();
      if (p_node->Output(0).isNone()) {
        p_node->Output(0) = at::native::max(self, other);
        return;
      }
      auto& out = p_node->Output(0).toTensor();
      fastResizeToZero(out);
      at::native::max_out(self, other, out);
    };
  }
  LogAndDumpSchema(n);
  return nullptr;
});

REGISTER_OPERATOR_FUNCTOR(
    aten::minimum,
    aten_minimum,


@@ -1721,6 +1721,61 @@ REGISTER_OPERATOR_FUNCTOR(aten::repeat, aten_repeat, [](Node* n) -> SROperator {
  };
});

REGISTER_OPERATOR_FUNCTOR(aten::max, aten_max, [](Node* n) -> SROperator {
  if (n->matches(torch::schema(
          "aten::max.other(Tensor self, Tensor other) -> Tensor"))) {
    return [](ProcessedNode* p_node) {
      const auto& self = p_node->Input(0).toTensor();
      const auto& other = p_node->Input(1).toTensor();
      if (p_node->Output(0).isNone()) {
        p_node->Output(0) = at::native::max(self, other);
        return;
      }
      auto& out = p_node->Output(0).toTensor();
      fastResizeToZero(out);
      at::native::max_out(self, other, out);
    };
  }

  if (n->matches(torch::schema(
          "aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)"))) {
    return [](ProcessedNode* p_node) {
      const auto& self = p_node->Input(0).toTensor();
      auto dim = p_node->Input(1).toInt();
      const auto keepdim = p_node->Input(2).toBool();
      if (p_node->Output(0).isNone()) {
        p_node->Output(0) = create_empty_from(self);
      }
      if (p_node->Output(1).isNone()) {
        p_node->Output(1) = create_empty_from(self, at::kLong);
      }
      auto& values = p_node->Output(0).toTensor();
      auto& indices = p_node->Output(1).toTensor();
      fastResizeToZero(values);
      fastResizeToZero(indices);
      at::cpu::max_out(values, indices, self, dim, keepdim);
    };
  }

  if (n->matches(torch::schema("aten::max(Tensor self) -> Tensor"))) {
    return [](ProcessedNode* p_node) {
      const auto& self = p_node->Input(0).toTensor();
      if (p_node->Output(0).isNone()) {
        p_node->Output(0) = create_empty_from(self);
      }
      auto& value = p_node->Output(0).toTensor();
      fastResizeToZero(value);
      at::cpu::amax_out(value, self);
    };
  }

  LogAndDumpSchema(n);
  return nullptr;
});

REGISTER_OPERATOR_FUNCTOR(aten::sign, aten_sign, [](Node* n) -> SROperator {
  if (!n->matches(torch::schema("aten::sign.Tensor(Tensor input) -> Tensor"))) {
    LogAndDumpSchema(n);