[fx2trt] break down div (#71172)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71172

Break down `div` into smaller ops so that the div variants look like all the other elementwise ops.

Use the Python `operator` division ops instead of `torch.div` where possible, to avoid converting literal numbers to torch tensors, as in the following:
```
a = 1
b = 2

# `c` would be 0.5
c = a / b

# `c` would be a torch tensor, e.g. tensor(0.5)
c = torch.div(a, b)
```

The problem we saw on shufflenet is that a size op is followed by a div op, which results in int64 tensors in the acc-traced graph (the acc tracer turned operator.truediv into acc_ops.div, which used torch.div). The TRT splitter then splits out the reshape op that consumes the div's output, because we have a rule to split out ops that take int64 tensors as inputs.
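
For illustration, a minimal sketch of that pattern, using a hypothetical module rather than the actual shufflenet code:

```
import torch

class ShuffleLike(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n, c, h, w = x.shape
        # Traced through operator division, `c // 2` stays a Python int.
        # Traced to torch.div, it would become an int64 tensor, and the
        # reshape that consumes it would be split out of the TRT subgraph.
        return x.reshape(n, c // 2, 2, h, w)
```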

Test Plan: Unit tests.

Reviewed By: wushirong

Differential Revision: D33482231

fbshipit-source-id: 508a171520c4e5b4188cfc5c30c1370ba9db1c55
Author: Shiyan Deng
Date: 2022-01-12 09:41:14 -08:00
Committed by: Facebook GitHub Bot
Parent: 6a40bb0fdf
Commit: 54fe2741a1

5 changed files with 90 additions and 62 deletions


```
@@ -12,13 +12,14 @@ from torch.testing._internal.common_utils import run_tests
 elementwise_ops = [
     ((lambda x, y: x + y), acc_ops.add),
     ((lambda x, y: x - y), acc_ops.sub),
-    # Avoid dividing by 0.
-    ((lambda x, y: x / (y + 1.0)), acc_ops.div),
-    ((lambda x, y: x // (y + 1.0)), acc_ops.div),
-    ((lambda x, y: torch.div(x, y + 1.0, rounding_mode="trunc")), acc_ops.div),
-    ((lambda x, y: torch.div(x, y + 1.0, rounding_mode="floor")), acc_ops.div),
-    ((lambda x, y: torch.div(x, y + 1.0)), acc_ops.div),
-    ((lambda x, y: torch.floor_divide(x, y + 1.0)), acc_ops.div),
+    ((lambda x, y: x / y), acc_ops.div),
+    ((lambda x, y: x // y), acc_ops.floor_div),
+    ((lambda x, y: torch.div(x, y, rounding_mode="trunc")), acc_ops.trunc_div),
+    ((lambda x, y: torch.div(x, y, rounding_mode="floor")), acc_ops.floor_div),
+    ((lambda x, y: torch.div(x, y)), acc_ops.div),
+    # torch.floor_divide rounds the result toward zero, rather than -Inf.
+    # https://github.com/pytorch/pytorch/issues/43874
+    ((lambda x, y: torch.floor_divide(x, y)), acc_ops.trunc_div),
     ((lambda x, y: x * y), acc_ops.mul),
     (torch.pow, acc_ops.pow),
 ]

@@ -36,7 +37,8 @@ class TestBinaryOpConverters(AccTestCase):
                 return self.orig_op(x, x)

         m = TestModule(orig_op)
-        inputs = [torch.randn(1, 1)]
+        # Avoid dividing by 0.
+        inputs = [torch.rand(1, 1) + 1]
         self.run_test(m, inputs, expected_ops={expected_op})

     @parameterized.expand([(op[1].__name__, op[0], op[1]) for op in elementwise_ops])
```


```
@@ -1457,6 +1457,17 @@ class AccTracerTest(unittest.TestCase):
     def test_torch_mul(self):
         self._make_acc_op_function_test(acc_ops.mul, lambda x: torch.mul(x, 7))

+    def test_div(self):
+        self._make_acc_op_function_test(acc_ops.div, lambda x: torch.div(x, 2))
+        self._make_acc_op_function_test(acc_ops.div, lambda x: x / 2)
+
+    def test_floor_div(self):
+        self._make_acc_op_function_test(acc_ops.floor_div, lambda x: torch.div(x, 2, rounding_mode="floor"))
+
+    def test_trunc_div(self):
+        self._make_acc_op_function_test(acc_ops.trunc_div, lambda x: torch.div(x, 2, rounding_mode="trunc"))
+        self._make_acc_op_function_test(acc_ops.trunc_div, lambda x: torch.floor_divide(x, 2))
+
     def test_view(self):
         """
         Test that Tensor.view is traced correctly.

@@ -1912,6 +1923,8 @@ class AccTracerTest(unittest.TestCase):
             acc_ops.sub,
             acc_ops.mul,
             acc_ops.div,
+            acc_ops.floor_div,
+            acc_ops.trunc_div,
             acc_ops.pow,
             acc_ops.relu,
             acc_ops.leaky_relu,
```
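
For reference, one way to check these mappings end to end. This is a sketch only; it assumes `acc_tracer` and `acc_ops` are importable from `torch.fx.experimental.fx_acc` and that `acc_tracer.trace` takes a module plus sample inputs:

```
import torch
from torch.fx.experimental.fx_acc import acc_ops, acc_tracer

class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.div(x, y, rounding_mode="floor")

traced = acc_tracer.trace(M(), [torch.randn(2, 2), torch.rand(2, 2) + 1])
# The torch.div call should now show up as acc_ops.floor_div in the graph.
assert any(node.target is acc_ops.floor_div for node in traced.graph.nodes)
```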


```
@@ -1161,21 +1161,33 @@ def acc_ops_div(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    if kwargs["rounding_mode"] == "trunc":
-        inputs = kwargs["input"]
-        other = kwargs["other"]
-        return trunc_div(inputs, other, network, target, name)
-    elif kwargs["rounding_mode"] == "floor":
-        return add_binary_elementwise_layer(
-            network, kwargs["input"], kwargs["other"], trt.ElementWiseOperation.FLOOR_DIV, target, name
-        )
-    elif kwargs["rounding_mode"] is None:
-        return add_binary_elementwise_layer(
-            network, kwargs["input"], kwargs["other"], trt.ElementWiseOperation.DIV, target, name
-        )
-    else:
-        mode = kwargs["rounding_mode"]
-        raise RuntimeError(f"Div received mode {mode} that is not supported!")
+    return add_binary_elementwise_layer(
+        network, kwargs["input"], kwargs["other"], trt.ElementWiseOperation.DIV, target, name
+    )
+
+
+@tensorrt_converter(acc_ops.floor_div)
+def acc_ops_floor_div(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return add_binary_elementwise_layer(
+        network, kwargs["input"], kwargs["other"], trt.ElementWiseOperation.FLOOR_DIV, target, name
+    )
+
+
+@tensorrt_converter(acc_ops.trunc_div)
+def acc_ops_trunc_div(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return trunc_div(kwargs["input"], kwargs["other"], network, target, name)


 @tensorrt_converter(acc_ops.mul)
```
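
`acc_ops.trunc_div` is lowered through the `trunc_div` helper, which composes TensorRT layers. As a reference for the intended semantics only (not the TRT lowering itself), truncated division is plain division rounded toward zero:

```
import torch

def trunc_div_ref(input: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
    # Divide, then round toward zero.
    return torch.trunc(input / other)

x = torch.tensor([-7.0, 7.0])
y = torch.tensor([2.0, -2.0])
assert torch.equal(trunc_div_ref(x, y), torch.div(x, y, rounding_mode="trunc"))
```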
```
@@ -2106,14 +2118,18 @@ def acc_ops_cumsum(
     set_layer_name(running_sum, target, f"{name}_running_sum_1")
     running_sum_tensor = running_sum.get_output(0)
-    current_sum = add_binary_elementwise_layer(network, data, running_sum_tensor, trt.ElementWiseOperation.SUM, target, "sum_1")
+    current_sum = add_binary_elementwise_layer(
+        network, data, running_sum_tensor, trt.ElementWiseOperation.SUM, target, f"{name}_sum_1"
+    )
     running_sum.set_input(1, current_sum)

     running_sum = loop.add_recurrence(zero_tensor)
     set_layer_name(running_sum, target, f"{name}_running_sum_2")
     running_sum_tensor = running_sum.get_output(0)
-    current_sum = add_binary_elementwise_layer(network, data, running_sum_tensor, trt.ElementWiseOperation.SUM, target, "sum_2")
+    current_sum = add_binary_elementwise_layer(
+        network, data, running_sum_tensor, trt.ElementWiseOperation.SUM, target, f"{name}_sum_2"
+    )
     running_sum.set_input(1, current_sum)

     loop_output = loop.add_loop_output(current_sum, trt.LoopOutput.CONCATENATE, dim)
```


```
@@ -630,24 +630,6 @@ def mul(*, input, other):
     return input * other


-# torch.floor_divide is announced to be deprecated; consider using torch.div() with 'trunc' or 'floor'
-# mode instead.
-# This implementation matches torch.floor_divide's behavior, which for negative numbers rounds the
-# result toward zero, rather than -Inf.
-@register_custom_acc_mapper_fn(
-    op_and_target=("call_function", torch.floor_divide),
-    arg_replacement_tuples=[
-        ("input", "input"),
-        ("other", "other"),
-    ],
-)
-@register_custom_acc_mapper_fn(
-    op_and_target=("call_function", operator.floordiv),
-    arg_replacement_tuples=[
-        ("input", "input"),
-        ("other", "other"),
-    ],
-)
 @register_custom_acc_mapper_fn(
     op_and_target=("call_function", torch.div),
     arg_replacement_tuples=[

@@ -656,32 +638,47 @@ def mul(*, input, other):
         ("rounding_mode", "rounding_mode", this_arg_is_optional),
     ],
 )
-@register_custom_acc_mapper_fn(
-    op_and_target=("call_function", operator.truediv),
-    arg_replacement_tuples=[
-        ("input", "input"),
-        ("other", "other"),
-    ],
-)
 def div_mapper(node: torch.fx.Node, mod: torch.fx.GraphModule) -> torch.fx.Node:
     with node.graph.inserting_before(node):
         div_kwargs = dict(node.kwargs)
-        if "rounding_mode" not in div_kwargs and node.op == "call_function":
-            div_kwargs["rounding_mode"] = None
-            if node.target is torch.floor_divide:
-                div_kwargs["rounding_mode"] = "trunc"
-            elif node.target is operator.floordiv:
-                div_kwargs["rounding_mode"] = "floor"
-            elif node.target is operator.truediv:
-                div_kwargs["rounding_mode"] = None
-        div_node = node.graph.call_function(div, kwargs=div_kwargs)
+        if "rounding_mode" not in div_kwargs or div_kwargs["rounding_mode"] is None:
+            div_node = node.graph.call_function(div, kwargs={"input": div_kwargs["input"], "other": div_kwargs["other"]})
+        elif div_kwargs["rounding_mode"] == "trunc":
+            div_node = node.graph.call_function(trunc_div, kwargs={"input": div_kwargs["input"], "other": div_kwargs["other"]})
+        elif div_kwargs["rounding_mode"] == "floor":
+            div_node = node.graph.call_function(floor_div, kwargs={"input": div_kwargs["input"], "other": div_kwargs["other"]})
+        else:
+            raise RuntimeError(f"Unhandled div rounding mode {div_kwargs['rounding_mode']}")
         div_node.meta = node.meta.copy()
         return div_node


 @register_acc_op_properties(AccOpProperty.pointwise)
+@register_acc_op_mapping(op_and_target=("call_function", operator.truediv))
 @register_acc_op
-def div(*, input, other, rounding_mode=None):
-    return torch.div(input, other, rounding_mode=rounding_mode)
+def div(*, input, other):
+    return input / other
+
+
+@register_acc_op_properties(AccOpProperty.pointwise)
+@register_acc_op_mapping(op_and_target=("call_function", operator.floordiv))
+@register_acc_op
+def floor_div(*, input, other):
+    # This is a temp fix: operator.floordiv on tensors currently translates
+    # into torch.floor_divide, which would throw an error. Once that is
+    # fixed we can stick to `input // other`.
+    if isinstance(input, torch.Tensor) or isinstance(other, torch.Tensor):
+        return torch.div(input, other, rounding_mode="floor")
+    return input // other
+
+
+# torch.floor_divide rounds the result toward zero, rather than -Inf.
+# https://github.com/pytorch/pytorch/issues/43874
+@register_acc_op_mapping(op_and_target=("call_function", torch.floor_divide))
+@register_acc_op_properties(AccOpProperty.pointwise)
+@register_acc_op
+def trunc_div(*, input, other):
+    return torch.div(input, other, rounding_mode="trunc")
+
+
 @register_acc_op_properties(AccOpProperty.pointwise)
```
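
The two rounding modes only disagree on negative, non-integral quotients, e.g.:

```
import torch

a, b = torch.tensor(-7.0), torch.tensor(2.0)
print(torch.div(a, b, rounding_mode="floor"))  # tensor(-4.), rounds toward -Inf
print(torch.div(a, b, rounding_mode="trunc"))  # tensor(-3.), rounds toward zero
# At the time of this change, torch.floor_divide also truncated, which is
# why it maps to acc_ops.trunc_div (pytorch/pytorch#43874).
print(torch.floor_divide(a, b))
```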


```
@@ -171,7 +171,7 @@ class OpSupports:
         submodules: t.Mapping[str, torch.nn.Module],
         node: torch.fx.Node,
     ) -> bool:
-        for arg in node._input_nodes:
+        for arg in node.all_input_nodes:
             # escape dtype check for get_attr node
             if arg.op == "get_attr":
                 continue
```
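
`Node.all_input_nodes` is the public equivalent of the private `_input_nodes` attribute; it lists every node feeding the current one, from both args and kwargs. A small standalone illustration:

```
import torch
import torch.fx

class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y

gm = torch.fx.symbolic_trace(M())
for node in gm.graph.nodes:
    if node.op == "call_function":
        # Each entry is a torch.fx.Node that this node consumes.
        print(node.name, [n.name for n in node.all_input_nodes])
```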