[SR] Codegen for aten::clamp (#76340)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76340

NNC kernel for `clamp` scalar case
ghstack-source-id: 155466507

Reviewed By: navahgar, huiguoo

Differential Revision: D35904019

fbshipit-source-id: e4115757f7e2cbdf364b88be3f599dfc3028750f
(cherry picked from commit bdc4b918bc5a14490f46c79793f764b28c18388f)
This commit is contained in:
Mike Iovine
2022-05-04 15:56:47 -07:00
committed by PyTorch MergeBot
parent c59d5f17d9
commit cac2733af1
4 changed files with 73 additions and 2 deletions

View File

@ -174,6 +174,44 @@ TEST(StaticRuntime, Clamp) {
testStaticRuntime(clamp_script_2, {a, min_t, max_t}, {b, max_t1, min_t1});
}
TEST(StaticRuntime, ClampMinOnly) {
const auto src = R"JIT(
def forward(self, inp: Tensor, min: float):
a = torch.clamp(inp, min, None).clone()
return (a)
)JIT";
auto a = at::randn({2, 3});
auto b = at::randn({4, 3, 2});
testStaticRuntime(src, {a, 0.5});
testStaticRuntime(src, {a, 0.5}, {b, 0.25});
}
TEST(StaticRuntime, ClampMaxOnly) {
const auto src = R"JIT(
def forward(self, inp: Tensor, max: float):
a = torch.clamp(inp, None, max).clone()
return (a)
)JIT";
auto a = at::randn({2, 3});
auto b = at::randn({4, 3, 2});
testStaticRuntime(src, {a, 0.5});
testStaticRuntime(src, {a, 0.5}, {b, 0.25});
}
TEST(StaticRuntime, ClampIntTensor) {
const auto src = R"JIT(
def forward(self, inp: Tensor, min: float, max: float):
a = torch.clamp(inp, min, max).clone()
return (a)
)JIT";
auto a = at::randint(0, 20, {2, 3});
auto b = at::randint(0, 20, {4, 3, 2});
auto min = 5.0f;
auto max = 5.0f;
testStaticRuntime(src, {a, min, max});
testStaticRuntime(src, {a, min, max}, {b, min, max});
}
TEST(StaticRuntime, LenWithTuple) {
const auto src = R"IR(
graph(%input : int[]):