Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[SR] Codegen for aten::clamp (#76340)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/76340

NNC kernel for the `clamp` scalar case.

ghstack-source-id: 155466507
Reviewed By: navahgar, huiguoo
Differential Revision: D35904019
fbshipit-source-id: e4115757f7e2cbdf364b88be3f599dfc3028750f
(cherry picked from commit bdc4b918bc5a14490f46c79793f764b28c18388f)
Committed by: PyTorch MergeBot
Parent: c59d5f17d9
Commit: cac2733af1
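For context, the scalar case this kernel covers clamps each element of a float tensor against optional min/max bounds, with an absent bound defaulting to the float lowest/max limit (the same defaults the operator change below falls back to). A minimal reference sketch of that semantics, not code from this PR; the function name and use of std::vector are illustrative only:

#include <algorithm>
#include <limits>
#include <optional>
#include <vector>

// Reference semantics of the float scalar-clamp path (illustrative only):
// each element is clamped to [lo, hi], and an absent bound defaults to the
// float lowest/max limit, matching the fallback values used in the operator.
std::vector<float> clamp_reference(
    const std::vector<float>& input,
    std::optional<float> min,
    std::optional<float> max) {
  const float lo = min.value_or(std::numeric_limits<float>::lowest());
  const float hi = max.value_or(std::numeric_limits<float>::max());
  std::vector<float> out(input.size());
  for (size_t i = 0; i < input.size(); ++i) {
    out[i] = std::clamp(input[i], lo, hi);
  }
  return out;
}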
@@ -174,6 +174,44 @@ TEST(StaticRuntime, Clamp) {
   testStaticRuntime(clamp_script_2, {a, min_t, max_t}, {b, max_t1, min_t1});
 }

+TEST(StaticRuntime, ClampMinOnly) {
+  const auto src = R"JIT(
+    def forward(self, inp: Tensor, min: float):
+        a = torch.clamp(inp, min, None).clone()
+        return (a)
+  )JIT";
+  auto a = at::randn({2, 3});
+  auto b = at::randn({4, 3, 2});
+  testStaticRuntime(src, {a, 0.5});
+  testStaticRuntime(src, {a, 0.5}, {b, 0.25});
+}
+
+TEST(StaticRuntime, ClampMaxOnly) {
+  const auto src = R"JIT(
+    def forward(self, inp: Tensor, max: float):
+        a = torch.clamp(inp, None, max).clone()
+        return (a)
+  )JIT";
+  auto a = at::randn({2, 3});
+  auto b = at::randn({4, 3, 2});
+  testStaticRuntime(src, {a, 0.5});
+  testStaticRuntime(src, {a, 0.5}, {b, 0.25});
+}
+
+TEST(StaticRuntime, ClampIntTensor) {
+  const auto src = R"JIT(
+    def forward(self, inp: Tensor, min: float, max: float):
+        a = torch.clamp(inp, min, max).clone()
+        return (a)
+  )JIT";
+  auto a = at::randint(0, 20, {2, 3});
+  auto b = at::randint(0, 20, {4, 3, 2});
+  auto min = 5.0f;
+  auto max = 5.0f;
+  testStaticRuntime(src, {a, min, max});
+  testStaticRuntime(src, {a, min, max}, {b, min, max});
+}
+
 TEST(StaticRuntime, LenWithTuple) {
   const auto src = R"IR(
     graph(%input : int[]):
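The ClampMinOnly and ClampMaxOnly scripts above exercise the one-sided forms of clamp. A small eager-mode illustration of the same behaviour using the public ATen API (not part of this diff):

#include <ATen/ATen.h>

void clamp_one_sided_example() {
  auto t = at::randn({2, 3});
  // min only: every element is raised to at least 0.5, no upper bound.
  auto lo_only = at::clamp(t, /*min=*/0.5, /*max=*/c10::nullopt);
  // max only: every element is capped at 0.5, no lower bound.
  auto hi_only = at::clamp(t, /*min=*/c10::nullopt, /*max=*/0.5);
  (void)lo_only;
  (void)hi_only;
}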
@@ -607,7 +607,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::addmm, aten_addmm, [](Node* n) -> SROperator {
 REGISTER_OPERATOR_FUNCTOR(aten::clamp, aten_clamp, [](Node* n) -> SROperator {
   if (n->matches(torch::schema(
           "aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor"))) {
-    return [](ProcessedNode* p_node) {
+    return [te = createClamp()](ProcessedNode* p_node) {
       const auto& in0_t = p_node->Input(0).toTensor();
       if (p_node->Output(0).isNone()) {
         p_node->Output(0) = create_empty_from(in0_t);
@@ -616,7 +616,17 @@ REGISTER_OPERATOR_FUNCTOR(aten::clamp, aten_clamp, [](Node* n) -> SROperator {
       fastResizeToZero(out_t);
       auto in1_s = p_node->Input(1).toOptional<at::Scalar>();
       auto in2_s = p_node->Input(2).toOptional<at::Scalar>();
-      at::cpu::clamp_out(out_t, in0_t, in1_s, in2_s);
+      if (!te->checkInput<float>(in0_t)) {
+        at::cpu::clamp_out(out_t, in0_t, in1_s, in2_s);
+        return;
+      }
+      at::native::resize_(out_t, in0_t.sizes(), c10::nullopt);
+      auto output_size = in0_t.numel();
+      auto min = in1_s.has_value() ? in1_s->toFloat()
+                                   : std::numeric_limits<float>::lowest();
+      auto max = in2_s.has_value() ? in2_s->toFloat()
+                                   : std::numeric_limits<float>::max();
+      te->call({out_t.data_ptr(), in0_t.data_ptr(), &min, &max, &output_size});
     };
   }
   if (n->matches(
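The updated functor only calls the generated kernel when the input passes the wrapper's float check; otherwise it falls back to at::cpu::clamp_out. The exact conditions of TEWrapper::checkInput are not shown in this diff; the helper below is a hedged illustration of the kind of guard involved:

#include <ATen/ATen.h>

// Illustrative guard only: the real TEWrapper::checkInput<float> may apply
// additional conditions beyond dtype and contiguity.
bool can_use_nnc_clamp(const at::Tensor& t) {
  return t.scalar_type() == at::kFloat && t.is_contiguous();
}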
@@ -4,6 +4,7 @@
 #include <torch/csrc/jit/ir/ir.h>
 #include <torch/csrc/jit/jit_log.h>
 #include <torch/csrc/jit/runtime/static/impl.h>
+#include <torch/csrc/jit/tensorexpr/operators/misc.h>
 #include <torch/csrc/jit/tensorexpr/operators/operators.h>

 namespace torch {
@@ -190,6 +191,27 @@ std::shared_ptr<TEWrapper> createSigmoid() {
   return wrap;
 }

+std::shared_ptr<TEWrapper> createClamp() {
+  static auto clamp_symbol = c10::Symbol::fromQualString("aten::clamp");
+  auto wrap = lookupNNCCache(clamp_symbol);
+  if (wrap) {
+    return wrap;
+  }
+  wrap = std::make_shared<TEWrapper>();
+  auto N = VarHandle("N", kInt);
+  auto min_handle = VarHandle("min", kFloat);
+  auto max_handle = VarHandle("max", kFloat);
+
+  BufHandle A("A", {N}, kFloat);
+  Tensor result = Compute("aten_clamp", {N}, [&](const VarHandle& i) {
+    auto a = A.load(i);
+    return tensorexpr::clamp(min_handle, max_handle, a);
+  });
+  wrap = wrapTECompute(wrap, result, {A, min_handle, max_handle, N});
+  updateNNCCache(clamp_symbol, wrap);
+  return wrap;
+}
+
 std::shared_ptr<TEWrapper> createSignedLog1p() {
   static auto signed_log1p_symbol =
       c10::Symbol::fromQualString("static_runtime::signed_log1p");
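createClamp builds a one-dimensional element-wise kernel over a float buffer of length N, with the min/max bounds passed in as scalar arguments and the compiled wrapper cached under the aten::clamp symbol. Assuming tensorexpr::clamp(lo, hi, x) lowers to min(max(x, lo), hi), the per-element math is equivalent to this small sketch (illustrative, not code from the PR):

#include <algorithm>

// Per-element computation the generated loop is expected to perform,
// under the assumption stated above.
inline float clamp_elem(float x, float lo, float hi) {
  return std::min(std::max(x, lo), hi);
}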
@@ -38,6 +38,7 @@ std::shared_ptr<TEWrapper> createRelu();
 std::shared_ptr<TEWrapper> createTanh();
 std::shared_ptr<TEWrapper> createSigmoid();
 std::shared_ptr<TEWrapper> createSignedLog1p();
+std::shared_ptr<TEWrapper> createClamp();

 } // namespace jit
 } // namespace torch