[Static Runtime] Support clamp.Tensor (#58191)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58191

There are two clamp overloads: clamp.Scalar and clamp.Tensor. Static Runtime (SR) needs to either support both or have checks in place to avoid runtime errors. Supporting both is not too hard, so here we are.
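
For context, the two overloads differ only in how the bounds are passed. A minimal sketch exercising both through the ATen C++ API (assuming the Tensor overload is exposed as at::clamp; the inputs mirror the test added below):

#include <ATen/ATen.h>

// Sketch: the two clamp overloads, called directly from C++.
void clamp_overloads_demo() {
  auto a = at::randn({2, 3});

  // clamp.Scalar: bounds passed as scalars.
  auto out_s = at::clamp(a, -1, 1);

  // clamp.Tensor: bounds passed as tensors and applied elementwise.
  auto min_t = at::full_like(a, -1);
  auto max_t = at::full_like(a, 1);
  auto out_t = at::clamp(a, min_t, max_t);
}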

Reviewed By: edvgha

Differential Revision: D28371949

fbshipit-source-id: 0ec6b8a0b8c6277e50d8e51e4e7a45aa62211e22
Commit: 993a35a8cb
Parent: 1f3807ce5d
Author: Hao Lu
Date: 2021-05-13 17:46:00 -07:00
Committed by: Facebook GitHub Bot

3 changed files with 33 additions and 4 deletions


@@ -337,3 +337,15 @@ const std::string repeat = R"JIT(
   def forward(self, a: Tensor, repeats: List[int]):
       return torch.repeat(a, repeats)
 )JIT";
+
+const auto clamp_script_1 = R"JIT(
+  def forward(self, inp: Tensor, min: int, max: int):
+      a = torch.clamp(inp, min, max)
+      return (a)
+)JIT";
+
+const auto clamp_script_2 = R"JIT(
+  def forward(self, inp: Tensor, min: Tensor, max: Tensor):
+      a = torch.clamp(inp, min, max)
+      return (a)
+)JIT";


@@ -141,6 +141,15 @@ TEST(StaticRuntime, Clone) {
   testStaticRuntime(clone_script_1, args_1);
 }
 
+TEST(StaticRuntime, Clamp) {
+  auto a = at::randn({2, 3});
+  auto max_t = at::full_like(a, 1);
+  auto min_t = at::full_like(a, -1);
+
+  testStaticRuntime(clamp_script_1, {a, -1, 1});
+  testStaticRuntime(clamp_script_2, {a, min_t, max_t});
+}
+
 TEST(StaticRuntime, Logit) {
   auto a = at::ones({2, 3});
   double b = 1e-6;
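
testStaticRuntime roughly compiles the script with the JIT and checks the Static Runtime output against the reference interpreter. A hedged sketch of driving the new clamp.Tensor path directly through StaticModule (the Static Runtime entry point; exact headers and call signatures have shifted across PyTorch versions):

#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/script.h>

// Sketch: script clamp_script_2 and run it under Static Runtime.
void run_clamp_tensor_through_sr() {
  torch::jit::Module mod("m");
  mod.define(clamp_script_2);
  torch::jit::StaticModule smod(mod);

  auto a = at::randn({2, 3});
  std::vector<c10::IValue> args{a, at::full_like(a, -1), at::full_like(a, 1)};
  auto out = smod(args, {}).toTensor();
}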


@@ -307,7 +307,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::addmm, aten_addmm, [](Node* n) -> SROperator {
   };
 });
 
-// TODO: support
-// clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor
+// clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor
+// clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_OPERATOR_FUNCTOR(aten::clamp, aten_clamp, [](Node* n) -> SROperator {
@@ -316,14 +316,22 @@ REGISTER_OPERATOR_FUNCTOR(aten::clamp, aten_clamp, [](Node* n) -> SROperator {
   }
   return [](ProcessedNode* p_node) {
     const auto& in0_t = p_node->Input(0).toTensor();
-    const auto in1_s = p_node->Input(1).toOptional<at::Scalar>();
-    const auto in2_s = p_node->Input(2).toOptional<at::Scalar>();
     if (p_node->Output(0).isNone()) {
       p_node->Output(0) = create_empty_from(in0_t);
     }
     auto& out_t = p_node->Output(0).toTensor();
     fastResizeToZero(out_t);
-    at::native::clamp_out(in0_t, in1_s, in2_s, out_t);
+    if (p_node->Input(1).isTensor()) {
+      auto in1_t = p_node->Input(1).toOptional<at::Tensor>();
+      auto in2_t = p_node->Input(2).toOptional<at::Tensor>();
+      at::native::clamp_out(in0_t, in1_t, in2_t, out_t);
+    } else {
+      auto in1_s = p_node->Input(1).toOptional<at::Scalar>();
+      auto in2_s = p_node->Input(2).toOptional<at::Scalar>();
+      at::native::clamp_out(in0_t, in1_s, in2_s, out_t);
+    }
   };
 });
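
The overload is picked at run time by inspecting the dynamic type of Input(1): a Tensor there means the graph matched clamp.Tensor, otherwise the Scalar path is taken. The surrounding allocation code is Static Runtime's usual out-variant pattern; a sketch of that pattern for a generic op follows (some_op_out is a hypothetical _out kernel; create_empty_from and fastResizeToZero are the same helpers used above):

// Sketch of the out-variant pattern, not a drop-in SROperator.
auto generic_out_variant = [](ProcessedNode* p_node) {
  const auto& in_t = p_node->Input(0).toTensor();
  if (p_node->Output(0).isNone()) {
    // First run: allocate an output matching the input's dtype/device.
    p_node->Output(0) = create_empty_from(in_t);
  }
  auto& out_t = p_node->Output(0).toTensor();
  // Later runs: keep the storage but drop the sizes so the _out kernel
  // can resize in place and reuse memory when shapes repeat.
  fastResizeToZero(out_t);
  at::native::some_op_out(in_t, out_t);  // hypothetical _out kernel
};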