mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[SR] Codegen for aten::clamp (#76340)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/76340 NNC kernel for `clamp` scalar case ghstack-source-id: 155466507 Reviewed By: navahgar, huiguoo Differential Revision: D35904019 fbshipit-source-id: e4115757f7e2cbdf364b88be3f599dfc3028750f (cherry picked from commit bdc4b918bc5a14490f46c79793f764b28c18388f)
This commit is contained in:
committed by
PyTorch MergeBot
parent
c59d5f17d9
commit
cac2733af1
@ -174,6 +174,44 @@ TEST(StaticRuntime, Clamp) {
|
||||
testStaticRuntime(clamp_script_2, {a, min_t, max_t}, {b, max_t1, min_t1});
|
||||
}
|
||||
|
||||
TEST(StaticRuntime, ClampMinOnly) {
|
||||
const auto src = R"JIT(
|
||||
def forward(self, inp: Tensor, min: float):
|
||||
a = torch.clamp(inp, min, None).clone()
|
||||
return (a)
|
||||
)JIT";
|
||||
auto a = at::randn({2, 3});
|
||||
auto b = at::randn({4, 3, 2});
|
||||
testStaticRuntime(src, {a, 0.5});
|
||||
testStaticRuntime(src, {a, 0.5}, {b, 0.25});
|
||||
}
|
||||
|
||||
TEST(StaticRuntime, ClampMaxOnly) {
|
||||
const auto src = R"JIT(
|
||||
def forward(self, inp: Tensor, max: float):
|
||||
a = torch.clamp(inp, None, max).clone()
|
||||
return (a)
|
||||
)JIT";
|
||||
auto a = at::randn({2, 3});
|
||||
auto b = at::randn({4, 3, 2});
|
||||
testStaticRuntime(src, {a, 0.5});
|
||||
testStaticRuntime(src, {a, 0.5}, {b, 0.25});
|
||||
}
|
||||
|
||||
TEST(StaticRuntime, ClampIntTensor) {
|
||||
const auto src = R"JIT(
|
||||
def forward(self, inp: Tensor, min: float, max: float):
|
||||
a = torch.clamp(inp, min, max).clone()
|
||||
return (a)
|
||||
)JIT";
|
||||
auto a = at::randint(0, 20, {2, 3});
|
||||
auto b = at::randint(0, 20, {4, 3, 2});
|
||||
auto min = 5.0f;
|
||||
auto max = 5.0f;
|
||||
testStaticRuntime(src, {a, min, max});
|
||||
testStaticRuntime(src, {a, min, max}, {b, min, max});
|
||||
}
|
||||
|
||||
TEST(StaticRuntime, LenWithTuple) {
|
||||
const auto src = R"IR(
|
||||
graph(%input : int[]):
|
||||
|
Reference in New Issue
Block a user