mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Static Runtime] Support clamp.Tensor (#58191)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/58191 There are two clamp overloads: clamp.Scalar and clamp.Tensor. SR needs to support both or has checks in place to avoid runtime errors. Supporting both is not too hard so here we are. Reviewed By: edvgha Differential Revision: D28371949 fbshipit-source-id: 0ec6b8a0b8c6277e50d8e51e4e7a45aa62211e22
This commit is contained in:
committed by
Facebook GitHub Bot
parent
1f3807ce5d
commit
993a35a8cb
@ -337,3 +337,15 @@ const std::string repeat = R"JIT(
|
||||
def forward(self, a: Tensor, repeats: List[int]):
|
||||
return torch.repeat(a, repeats)
|
||||
)JIT";
|
||||
|
||||
const auto clamp_script_1 = R"JIT(
|
||||
def forward(self, inp: Tensor, min: int, max: int):
|
||||
a = torch.clamp(inp, min, max)
|
||||
return (a)
|
||||
)JIT";
|
||||
|
||||
const auto clamp_script_2 = R"JIT(
|
||||
def forward(self, inp: Tensor, min: Tensor, max: Tensor):
|
||||
a = torch.clamp(inp, min, max)
|
||||
return (a)
|
||||
)JIT";
|
||||
|
@ -141,6 +141,15 @@ TEST(StaticRuntime, Clone) {
|
||||
testStaticRuntime(clone_script_1, args_1);
|
||||
}
|
||||
|
||||
TEST(StaticRuntime, Clamp) {
|
||||
auto a = at::randn({2, 3});
|
||||
auto max_t = at::full_like(a, 1);
|
||||
auto min_t = at::full_like(a, -1);
|
||||
|
||||
testStaticRuntime(clamp_script_1, {a, -1, 1});
|
||||
testStaticRuntime(clamp_script_2, {a, min_t, max_t});
|
||||
}
|
||||
|
||||
TEST(StaticRuntime, Logit) {
|
||||
auto a = at::ones({2, 3});
|
||||
double b = 1e-6;
|
||||
|
@ -307,7 +307,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::addmm, aten_addmm, [](Node* n) -> SROperator {
|
||||
};
|
||||
});
|
||||
|
||||
// TODO: support
|
||||
// clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor
|
||||
// clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
REGISTER_OPERATOR_FUNCTOR(aten::clamp, aten_clamp, [](Node* n) -> SROperator {
|
||||
@ -316,14 +316,22 @@ REGISTER_OPERATOR_FUNCTOR(aten::clamp, aten_clamp, [](Node* n) -> SROperator {
|
||||
}
|
||||
return [](ProcessedNode* p_node) {
|
||||
const auto& in0_t = p_node->Input(0).toTensor();
|
||||
const auto in1_s = p_node->Input(1).toOptional<at::Scalar>();
|
||||
const auto in2_s = p_node->Input(2).toOptional<at::Scalar>();
|
||||
|
||||
if (p_node->Output(0).isNone()) {
|
||||
p_node->Output(0) = create_empty_from(in0_t);
|
||||
}
|
||||
auto& out_t = p_node->Output(0).toTensor();
|
||||
fastResizeToZero(out_t);
|
||||
at::native::clamp_out(in0_t, in1_s, in2_s, out_t);
|
||||
|
||||
if (p_node->Input(1).isTensor()) {
|
||||
auto in1_t = p_node->Input(1).toOptional<at::Tensor>();
|
||||
auto in2_t = p_node->Input(2).toOptional<at::Tensor>();
|
||||
at::native::clamp_out(in0_t, in1_t, in2_t, out_t);
|
||||
} else {
|
||||
auto in1_s = p_node->Input(1).toOptional<at::Scalar>();
|
||||
auto in2_s = p_node->Input(2).toOptional<at::Scalar>();
|
||||
at::native::clamp_out(in0_t, in1_s, in2_s, out_t);
|
||||
}
|
||||
};
|
||||
});
|
||||
|
||||
|
Reference in New Issue
Block a user