Add nn.function.hardtanh in acc_tracer (#65639)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65639

This op is used by mobilenet v2.

Test Plan:
buck test glow/fb/fx/oss_acc_tracer:test_acc_tracer -- test_hardtanh
buck test glow/fb/fx/acc_tracer:test_acc_shape_inference -- hardtanh
buck test glow/fb/fx/oss_acc_tracer:test_acc_tracer -- test_hardtanh

Reviewed By: yinghai

Differential Revision: D31184297

fbshipit-source-id: 5a04319f6d16fb930372442616e27211107ecc67
This commit is contained in:
Rui Zhu
2021-09-27 18:38:13 -07:00
committed by Facebook GitHub Bot
parent 6a6ee92e36
commit e9327ed2ce

View File

@ -410,6 +410,18 @@ def dropout_mapper(node: torch.fx.Node, mod: nn.Module):
"""
return node.kwargs["input"]
@register_acc_op_mapping(
op_and_target=("call_function", nn.functional.hardtanh),
arg_replacement_tuples=[
("input", "input"),
("min_val", "left"),
("max_val", "right"),
],
)
@register_acc_op
def hardtanh(*, input, left, right):
return nn.functional.hardtanh(input, min_val=left, max_val=right)
@register_acc_op_mapping(
op_and_target=("call_function", nn.functional.hardsigmoid))
@register_acc_op