mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add nn.function.hardtanh in acc_tracer (#65639)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65639 This op is used by mobilenet v2. Test Plan: buck test glow/fb/fx/oss_acc_tracer:test_acc_tracer -- test_hardtanh buck test glow/fb/fx/acc_tracer:test_acc_shape_inference -- hardtanh buck test glow/fb/fx/oss_acc_tracer:test_acc_tracer -- test_hardtanh Reviewed By: yinghai Differential Revision: D31184297 fbshipit-source-id: 5a04319f6d16fb930372442616e27211107ecc67
This commit is contained in:
committed by
Facebook GitHub Bot
parent
6a6ee92e36
commit
e9327ed2ce
@ -410,6 +410,18 @@ def dropout_mapper(node: torch.fx.Node, mod: nn.Module):
|
||||
"""
|
||||
return node.kwargs["input"]
|
||||
|
||||
@register_acc_op_mapping(
|
||||
op_and_target=("call_function", nn.functional.hardtanh),
|
||||
arg_replacement_tuples=[
|
||||
("input", "input"),
|
||||
("min_val", "left"),
|
||||
("max_val", "right"),
|
||||
],
|
||||
)
|
||||
@register_acc_op
|
||||
def hardtanh(*, input, left, right):
|
||||
return nn.functional.hardtanh(input, min_val=left, max_val=right)
|
||||
|
||||
@register_acc_op_mapping(
|
||||
op_and_target=("call_function", nn.functional.hardsigmoid))
|
||||
@register_acc_op
|
||||
|
Reference in New Issue
Block a user