From e9327ed2cee458229dc3425fc38525dc81219e20 Mon Sep 17 00:00:00 2001 From: Rui Zhu Date: Mon, 27 Sep 2021 18:38:13 -0700 Subject: [PATCH] Add nn.function.hardtanh in acc_tracer (#65639) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65639 This op is used by mobilenet v2. Test Plan: buck test glow/fb/fx/oss_acc_tracer:test_acc_tracer -- test_hardtanh buck test glow/fb/fx/acc_tracer:test_acc_shape_inference -- hardtanh buck test glow/fb/fx/oss_acc_tracer:test_acc_tracer -- test_hardtanh Reviewed By: yinghai Differential Revision: D31184297 fbshipit-source-id: 5a04319f6d16fb930372442616e27211107ecc67 --- torch/fx/experimental/fx_acc/acc_ops.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/torch/fx/experimental/fx_acc/acc_ops.py b/torch/fx/experimental/fx_acc/acc_ops.py index 7f4651065922..2a77fd39a07b 100644 --- a/torch/fx/experimental/fx_acc/acc_ops.py +++ b/torch/fx/experimental/fx_acc/acc_ops.py @@ -410,6 +410,18 @@ def dropout_mapper(node: torch.fx.Node, mod: nn.Module): """ return node.kwargs["input"] +@register_acc_op_mapping( + op_and_target=("call_function", nn.functional.hardtanh), + arg_replacement_tuples=[ + ("input", "input"), + ("min_val", "left"), + ("max_val", "right"), + ], +) +@register_acc_op +def hardtanh(*, input, left, right): + return nn.functional.hardtanh(input, min_val=left, max_val=right) + @register_acc_op_mapping( op_and_target=("call_function", nn.functional.hardsigmoid)) @register_acc_op