[SR] Codegen for aten::clamp (#76340)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76340

NNC kernel for `clamp` scalar case
ghstack-source-id: 155466507

Reviewed By: navahgar, huiguoo

Differential Revision: D35904019

fbshipit-source-id: e4115757f7e2cbdf364b88be3f599dfc3028750f
(cherry picked from commit bdc4b918bc5a14490f46c79793f764b28c18388f)
This commit is contained in:
Mike Iovine
2022-05-04 15:56:47 -07:00
committed by PyTorch MergeBot
parent c59d5f17d9
commit cac2733af1
4 changed files with 73 additions and 2 deletions

View File

@ -174,6 +174,44 @@ TEST(StaticRuntime, Clamp) {
testStaticRuntime(clamp_script_2, {a, min_t, max_t}, {b, max_t1, min_t1});
}
// Verifies aten::clamp through Static Runtime when only the lower bound is
// supplied (max=None). The second testStaticRuntime call re-runs the graph
// with a different shape and bound to exercise the output-resize path.
TEST(StaticRuntime, ClampMinOnly) {
  const auto src = R"JIT(
    def forward(self, inp: Tensor, min: float):
        a = torch.clamp(inp, min, None).clone()
        return (a)
  )JIT";
  const auto first_input = at::randn({2, 3});
  const auto second_input = at::randn({4, 3, 2});
  testStaticRuntime(src, {first_input, 0.5});
  testStaticRuntime(src, {first_input, 0.5}, {second_input, 0.25});
}
// Verifies aten::clamp through Static Runtime when only the upper bound is
// supplied (min=None). The second testStaticRuntime call re-runs the graph
// with a different shape and bound to exercise the output-resize path.
TEST(StaticRuntime, ClampMaxOnly) {
  const auto src = R"JIT(
    def forward(self, inp: Tensor, max: float):
        a = torch.clamp(inp, None, max).clone()
        return (a)
  )JIT";
  const auto first_input = at::randn({2, 3});
  const auto second_input = at::randn({4, 3, 2});
  testStaticRuntime(src, {first_input, 0.5});
  testStaticRuntime(src, {first_input, 0.5}, {second_input, 0.25});
}
// Verifies aten::clamp through Static Runtime on an integer tensor with
// float scalar bounds. Integer inputs fail the NNC kernel's float-dtype
// check, so this exercises the at::cpu::clamp_out fallback path.
TEST(StaticRuntime, ClampIntTensor) {
  const auto src = R"JIT(
    def forward(self, inp: Tensor, min: float, max: float):
        a = torch.clamp(inp, min, max).clone()
        return (a)
  )JIT";
  auto a = at::randint(0, 20, {2, 3});
  auto b = at::randint(0, 20, {4, 3, 2});
  // Fix: the original set min == max (both 5.0f), collapsing the clamp range
  // to a single value so the upper bound was never independently tested.
  // Inputs span [0, 20), so [5, 15] exercises both bounds and pass-through.
  auto min = 5.0f;
  auto max = 15.0f;
  testStaticRuntime(src, {a, min, max});
  testStaticRuntime(src, {a, min, max}, {b, min, max});
}
TEST(StaticRuntime, LenWithTuple) {
const auto src = R"IR(
graph(%input : int[]):

View File

@ -607,7 +607,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::addmm, aten_addmm, [](Node* n) -> SROperator {
REGISTER_OPERATOR_FUNCTOR(aten::clamp, aten_clamp, [](Node* n) -> SROperator {
if (n->matches(torch::schema(
"aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor"))) {
return [](ProcessedNode* p_node) {
return [te = createClamp()](ProcessedNode* p_node) {
const auto& in0_t = p_node->Input(0).toTensor();
if (p_node->Output(0).isNone()) {
p_node->Output(0) = create_empty_from(in0_t);
@ -616,7 +616,17 @@ REGISTER_OPERATOR_FUNCTOR(aten::clamp, aten_clamp, [](Node* n) -> SROperator {
fastResizeToZero(out_t);
auto in1_s = p_node->Input(1).toOptional<at::Scalar>();
auto in2_s = p_node->Input(2).toOptional<at::Scalar>();
at::cpu::clamp_out(out_t, in0_t, in1_s, in2_s);
if (!te->checkInput<float>(in0_t)) {
at::cpu::clamp_out(out_t, in0_t, in1_s, in2_s);
return;
}
at::native::resize_(out_t, in0_t.sizes(), c10::nullopt);
auto output_size = in0_t.numel();
auto min = in1_s.has_value() ? in1_s->toFloat()
: std::numeric_limits<float>::lowest();
auto max = in2_s.has_value() ? in2_s->toFloat()
: std::numeric_limits<float>::max();
te->call({out_t.data_ptr(), in0_t.data_ptr(), &min, &max, &output_size});
};
}
if (n->matches(

View File

@ -4,6 +4,7 @@
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/csrc/jit/tensorexpr/operators/misc.h>
#include <torch/csrc/jit/tensorexpr/operators/operators.h>
namespace torch {
@ -190,6 +191,27 @@ std::shared_ptr<TEWrapper> createSigmoid() {
return wrap;
}
std::shared_ptr<TEWrapper> createClamp() {
static auto clamp_symbol = c10::Symbol::fromQualString("aten::clamp");
auto wrap = lookupNNCCache(clamp_symbol);
if (wrap) {
return wrap;
}
wrap = std::make_shared<TEWrapper>();
auto N = VarHandle("N", kInt);
auto min_handle = VarHandle("min", kFloat);
auto max_handle = VarHandle("max", kFloat);
BufHandle A("A", {N}, kFloat);
Tensor result = Compute("aten_clamp", {N}, [&](const VarHandle& i) {
auto a = A.load(i);
return tensorexpr::clamp(min_handle, max_handle, a);
});
wrap = wrapTECompute(wrap, result, {A, min_handle, max_handle, N});
updateNNCCache(clamp_symbol, wrap);
return wrap;
}
std::shared_ptr<TEWrapper> createSignedLog1p() {
static auto signed_log1p_symbol =
c10::Symbol::fromQualString("static_runtime::signed_log1p");

View File

@ -38,6 +38,7 @@ std::shared_ptr<TEWrapper> createRelu();
std::shared_ptr<TEWrapper> createTanh();
std::shared_ptr<TEWrapper> createSigmoid();
std::shared_ptr<TEWrapper> createSignedLog1p();
std::shared_ptr<TEWrapper> createClamp();
} // namespace jit
} // namespace torch