mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[static-runtime] out variant for aten::max (#78271)
Summary: Previously the op was auto-generated but it only covered the pointwise overload of aten::max. This adds support for reduction, overall and along a dim Test Plan: Added a unit test Differential Revision: D36656378 Pull Request resolved: https://github.com/pytorch/pytorch/pull/78271 Approved by: https://github.com/mikeiovine
This commit is contained in:
committed by
PyTorch MergeBot
parent
ac031e1326
commit
2679755bdc
@ -94,6 +94,43 @@ TEST(StaticRuntime, UnaryOps) {
|
||||
testStaticRuntime(aten_sum_1_true, args, args2);
|
||||
}
|
||||
|
||||
TEST(StaticRuntime, Max) {
|
||||
auto src_max_reduce = R"JIT(
|
||||
def forward(self, input):
|
||||
return torch.max(input).clone()
|
||||
)JIT";
|
||||
|
||||
auto src_max_dim = R"JIT(
|
||||
def forward(self, input, dim: int):
|
||||
values, indices = torch.max(input, dim)
|
||||
return values.clone(), indices.clone()
|
||||
)JIT";
|
||||
|
||||
auto src_max_dim_keepdim = R"JIT(
|
||||
def forward(self, input, dim: int):
|
||||
values, indices = torch.max(input, dim, keepdim=True)
|
||||
return values.clone(), indices.clone()
|
||||
)JIT";
|
||||
|
||||
auto src_max_pointwise = R"JIT(
|
||||
def forward(self, input, other):
|
||||
return torch.max(input, other).clone()
|
||||
)JIT";
|
||||
|
||||
auto input = at::randn({2, 3, 2});
|
||||
auto input_other = at::randn({2, 3, 2});
|
||||
auto large_input = at::randn({8, 9, 10});
|
||||
auto large_input_other = at::randn({8, 9, 10});
|
||||
|
||||
testStaticRuntime(src_max_reduce, {input});
|
||||
testStaticRuntime(src_max_dim, {input, 1});
|
||||
testStaticRuntime(src_max_dim, {input, 1}, {large_input, 0});
|
||||
testStaticRuntime(src_max_dim_keepdim, {input, 0});
|
||||
testStaticRuntime(src_max_dim_keepdim, {input, 0}, {large_input, 2});
|
||||
testStaticRuntime(src_max_pointwise, {input, input_other});
|
||||
testStaticRuntime(src_max_pointwise, {input, input_other}, {large_input, large_input_other});
|
||||
}
|
||||
|
||||
TEST(StaticRuntime, Sigmoid) {
|
||||
const auto sigmoid_script = R"JIT(
|
||||
def forward(self, inp: Tensor):
|
||||
|
@ -3167,25 +3167,6 @@ REGISTER_OPERATOR_FUNCTOR(
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
REGISTER_OPERATOR_FUNCTOR(aten::max, aten_max, [](Node* n) -> SROperator {
|
||||
if (n->matches(torch::schema(
|
||||
"aten::max.other(Tensor self, Tensor other) -> Tensor"))) {
|
||||
return [](ProcessedNode* p_node) {
|
||||
const auto& self = p_node->Input(0).toTensor();
|
||||
const auto& other = p_node->Input(1).toTensor();
|
||||
if (p_node->Output(0).isNone()) {
|
||||
p_node->Output(0) = at::native::max(self, other);
|
||||
return;
|
||||
}
|
||||
auto& out = p_node->Output(0).toTensor();
|
||||
fastResizeToZero(out);
|
||||
at::native::max_out(self, other, out);
|
||||
};
|
||||
}
|
||||
LogAndDumpSchema(n);
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
REGISTER_OPERATOR_FUNCTOR(
|
||||
aten::minimum,
|
||||
aten_minimum,
|
||||
|
@ -1721,6 +1721,61 @@ REGISTER_OPERATOR_FUNCTOR(aten::repeat, aten_repeat, [](Node* n) -> SROperator {
|
||||
};
|
||||
});
|
||||
|
||||
REGISTER_OPERATOR_FUNCTOR(aten::max, aten_max, [](Node* n) -> SROperator {
|
||||
if (n->matches(torch::schema(
|
||||
"aten::max.other(Tensor self, Tensor other) -> Tensor"))) {
|
||||
return [](ProcessedNode* p_node) {
|
||||
const auto& self = p_node->Input(0).toTensor();
|
||||
const auto& other = p_node->Input(1).toTensor();
|
||||
if (p_node->Output(0).isNone()) {
|
||||
p_node->Output(0) = at::native::max(self, other);
|
||||
return;
|
||||
}
|
||||
auto& out = p_node->Output(0).toTensor();
|
||||
fastResizeToZero(out);
|
||||
at::native::max_out(self, other, out);
|
||||
};
|
||||
}
|
||||
|
||||
if (n->matches(torch::schema(
|
||||
"aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)"))) {
|
||||
return [](ProcessedNode* p_node) {
|
||||
const auto& self = p_node->Input(0).toTensor();
|
||||
auto dim = p_node->Input(1).toInt();
|
||||
const auto keepdim = p_node->Input(2).toBool();
|
||||
|
||||
if (p_node->Output(0).isNone()) {
|
||||
p_node->Output(0) = create_empty_from(self);
|
||||
}
|
||||
|
||||
if (p_node->Output(1).isNone()) {
|
||||
p_node->Output(1) = create_empty_from(self, at::kLong);
|
||||
}
|
||||
|
||||
auto& values = p_node->Output(0).toTensor();
|
||||
auto& indices = p_node->Output(1).toTensor();
|
||||
fastResizeToZero(values);
|
||||
fastResizeToZero(indices);
|
||||
at::cpu::max_out(values, indices, self, dim, keepdim);
|
||||
};
|
||||
}
|
||||
|
||||
if (n->matches(torch::schema("aten::max(Tensor self) -> Tensor"))) {
|
||||
return [](ProcessedNode* p_node) {
|
||||
const auto& self = p_node->Input(0).toTensor();
|
||||
if (p_node->Output(0).isNone()) {
|
||||
p_node->Output(0) = create_empty_from(self);
|
||||
}
|
||||
auto& value = p_node->Output(0).toTensor();
|
||||
fastResizeToZero(value);
|
||||
at::cpu::amax_out(value, self);
|
||||
};
|
||||
}
|
||||
|
||||
LogAndDumpSchema(n);
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
REGISTER_OPERATOR_FUNCTOR(aten::sign, aten_sign, [](Node* n) -> SROperator {
|
||||
if (!n->matches(torch::schema("aten::sign.Tensor(Tensor input) -> Tensor"))) {
|
||||
LogAndDumpSchema(n);
|
||||
|
Reference in New Issue
Block a user