mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[fx2trt] break down div (#71172)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/71172 Break down div to smaller ops to make those div ops look like all other elementwise ops. Use operator div ops instead of torch div if possible to avoid converting literal numbers to torch tensor (like in the following). ``` a = 1 b = 2 // `c` would be 0.5 c = a / b // `c` would be torch.tensor([0.5]) c = torch.div(a, b) ``` The problem we saw on shufflenet is that there's size op followed by a div op which results in int64 tensors in acc traced graph (acc tracer turns operator.div to acc_ops.div which uses torch.div). And trt splitter splits out the reshape op that consumes the div op because we have a rule to split out ops that takes in int64 tensors as inputs. Test Plan: Unit tests. Reviewed By: wushirong Differential Revision: D33482231 fbshipit-source-id: 508a171520c4e5b4188cfc5c30c1370ba9db1c55
This commit is contained in:
committed by
Facebook GitHub Bot
parent
6a40bb0fdf
commit
54fe2741a1
@ -12,13 +12,14 @@ from torch.testing._internal.common_utils import run_tests
|
||||
elementwise_ops = [
|
||||
((lambda x, y: x + y), acc_ops.add),
|
||||
((lambda x, y: x - y), acc_ops.sub),
|
||||
# Avoid dividing by 0.
|
||||
((lambda x, y: x / (y + 1.0)), acc_ops.div),
|
||||
((lambda x, y: x // (y + 1.0)), acc_ops.div),
|
||||
((lambda x, y: torch.div(x, y + 1.0, rounding_mode="trunc")), acc_ops.div),
|
||||
((lambda x, y: torch.div(x, y + 1.0, rounding_mode="floor")), acc_ops.div),
|
||||
((lambda x, y: torch.div(x, y + 1.0)), acc_ops.div),
|
||||
((lambda x, y: torch.floor_divide(x, y + 1.0)), acc_ops.div),
|
||||
((lambda x, y: x / y), acc_ops.div),
|
||||
((lambda x, y: x // y), acc_ops.floor_div),
|
||||
((lambda x, y: torch.div(x, y, rounding_mode="trunc")), acc_ops.trunc_div),
|
||||
((lambda x, y: torch.div(x, y, rounding_mode="floor")), acc_ops.floor_div),
|
||||
((lambda x, y: torch.div(x, y)), acc_ops.div),
|
||||
# torch.floor_divide rounds result toward zero, rather than -Inf.
|
||||
# https://github.com/pytorch/pytorch/issues/43874
|
||||
((lambda x, y: torch.floor_divide(x, y)), acc_ops.trunc_div),
|
||||
((lambda x, y: x * y), acc_ops.mul),
|
||||
(torch.pow, acc_ops.pow),
|
||||
]
|
||||
@ -36,7 +37,8 @@ class TestBinaryOpConverters(AccTestCase):
|
||||
return self.orig_op(x, x)
|
||||
|
||||
m = TestModule(orig_op)
|
||||
inputs = [torch.randn(1, 1)]
|
||||
# Avoid dividing by 0.
|
||||
inputs = [torch.rand(1, 1) + 1]
|
||||
self.run_test(m, inputs, expected_ops={expected_op})
|
||||
|
||||
@parameterized.expand([(op[1].__name__, op[0], op[1]) for op in elementwise_ops])
|
||||
|
@ -1457,6 +1457,17 @@ class AccTracerTest(unittest.TestCase):
|
||||
def test_torch_mul(self):
|
||||
self._make_acc_op_function_test(acc_ops.mul, lambda x: torch.mul(x, 7))
|
||||
|
||||
def test_div(self):
|
||||
self._make_acc_op_function_test(acc_ops.div, lambda x: torch.div(x, 2))
|
||||
self._make_acc_op_function_test(acc_ops.div, lambda x: x / 2)
|
||||
|
||||
def test_floor_div(self):
|
||||
self._make_acc_op_function_test(acc_ops.floor_div, lambda x: torch.div(x, 2, rounding_mode="floor"))
|
||||
|
||||
def test_trunc_div(self):
|
||||
self._make_acc_op_function_test(acc_ops.trunc_div, lambda x: torch.div(x, 2, rounding_mode="trunc"))
|
||||
self._make_acc_op_function_test(acc_ops.trunc_div, lambda x: torch.floor_divide(x, 2))
|
||||
|
||||
def test_view(self):
|
||||
"""
|
||||
Test that Tensor.view is traced correctly.
|
||||
@ -1912,6 +1923,8 @@ class AccTracerTest(unittest.TestCase):
|
||||
acc_ops.sub,
|
||||
acc_ops.mul,
|
||||
acc_ops.div,
|
||||
acc_ops.floor_div,
|
||||
acc_ops.trunc_div,
|
||||
acc_ops.pow,
|
||||
acc_ops.relu,
|
||||
acc_ops.leaky_relu,
|
||||
|
@ -1161,21 +1161,33 @@ def acc_ops_div(
|
||||
kwargs: Dict[str, Argument],
|
||||
name: str,
|
||||
) -> Union[TRTTensor, Sequence[TRTTensor]]:
|
||||
if kwargs["rounding_mode"] == "trunc":
|
||||
inputs = kwargs["input"]
|
||||
other = kwargs["other"]
|
||||
return trunc_div(inputs, other, network, target, name)
|
||||
elif kwargs["rounding_mode"] == "floor":
|
||||
return add_binary_elementwise_layer(
|
||||
network, kwargs["input"], kwargs["other"], trt.ElementWiseOperation.FLOOR_DIV, target, name
|
||||
)
|
||||
elif kwargs["rounding_mode"] is None:
|
||||
return add_binary_elementwise_layer(
|
||||
network, kwargs["input"], kwargs["other"], trt.ElementWiseOperation.DIV, target, name
|
||||
)
|
||||
else :
|
||||
mode = kwargs["rounding_mode"]
|
||||
raise RuntimeError(f"Div received mode {mode} that is not supported!")
|
||||
return add_binary_elementwise_layer(
|
||||
network, kwargs["input"], kwargs["other"], trt.ElementWiseOperation.DIV, target, name
|
||||
)
|
||||
|
||||
|
||||
@tensorrt_converter(acc_ops.floor_div)
|
||||
def acc_ops_floor_div(
|
||||
network: TRTNetwork,
|
||||
target: Target,
|
||||
args: Tuple[Argument, ...],
|
||||
kwargs: Dict[str, Argument],
|
||||
name: str,
|
||||
) -> Union[TRTTensor, Sequence[TRTTensor]]:
|
||||
return add_binary_elementwise_layer(
|
||||
network, kwargs["input"], kwargs["other"], trt.ElementWiseOperation.FLOOR_DIV, target, name
|
||||
)
|
||||
|
||||
|
||||
@tensorrt_converter(acc_ops.trunc_div)
|
||||
def acc_ops_trunc_div(
|
||||
network: TRTNetwork,
|
||||
target: Target,
|
||||
args: Tuple[Argument, ...],
|
||||
kwargs: Dict[str, Argument],
|
||||
name: str,
|
||||
) -> Union[TRTTensor, Sequence[TRTTensor]]:
|
||||
return trunc_div(kwargs["input"], kwargs["other"], network, target, name)
|
||||
|
||||
|
||||
@tensorrt_converter(acc_ops.mul)
|
||||
@ -2106,14 +2118,18 @@ def acc_ops_cumsum(
|
||||
set_layer_name(running_sum, target, f"{name}_running_sum_1")
|
||||
running_sum_tensor = running_sum.get_output(0)
|
||||
|
||||
current_sum = add_binary_elementwise_layer(network, data, running_sum_tensor, trt.ElementWiseOperation.SUM, target, "sum_1")
|
||||
current_sum = add_binary_elementwise_layer(
|
||||
network, data, running_sum_tensor, trt.ElementWiseOperation.SUM, target, f"{name}_sum_1"
|
||||
)
|
||||
running_sum.set_input(1, current_sum)
|
||||
|
||||
running_sum = loop.add_recurrence(zero_tensor)
|
||||
set_layer_name(running_sum, target, f"{name}_running_sum_2")
|
||||
running_sum_tensor = running_sum.get_output(0)
|
||||
|
||||
current_sum = add_binary_elementwise_layer(network, data, running_sum_tensor, trt.ElementWiseOperation.SUM, target, "sum_2")
|
||||
current_sum = add_binary_elementwise_layer(
|
||||
network, data, running_sum_tensor, trt.ElementWiseOperation.SUM, target, f"{name}_sum_2"
|
||||
)
|
||||
running_sum.set_input(1, current_sum)
|
||||
|
||||
loop_output = loop.add_loop_output(current_sum, trt.LoopOutput.CONCATENATE, dim)
|
||||
|
@ -630,24 +630,6 @@ def mul(*, input, other):
|
||||
return input * other
|
||||
|
||||
|
||||
# Torch.floor_divide is announced to be deprecated, consider using torch.div() with 'trunc' or 'floor'
|
||||
# mode instead.
|
||||
# This implementation matches torch.floor_div's behavior, which for negative number the divide result
|
||||
# is round toward zero, rather than -Inf.
|
||||
@register_custom_acc_mapper_fn(
|
||||
op_and_target=("call_function", torch.floor_divide),
|
||||
arg_replacement_tuples=[
|
||||
("input", "input"),
|
||||
("other", "other"),
|
||||
],
|
||||
)
|
||||
@register_custom_acc_mapper_fn(
|
||||
op_and_target=("call_function", operator.floordiv),
|
||||
arg_replacement_tuples=[
|
||||
("input", "input"),
|
||||
("other", "other"),
|
||||
],
|
||||
)
|
||||
@register_custom_acc_mapper_fn(
|
||||
op_and_target=("call_function", torch.div),
|
||||
arg_replacement_tuples=[
|
||||
@ -656,32 +638,47 @@ def mul(*, input, other):
|
||||
("rounding_mode", "rounding_mode", this_arg_is_optional),
|
||||
],
|
||||
)
|
||||
@register_custom_acc_mapper_fn(
|
||||
op_and_target=("call_function", operator.truediv),
|
||||
arg_replacement_tuples=[
|
||||
("input", "input"),
|
||||
("other", "other"),
|
||||
],
|
||||
)
|
||||
def div_mapper(node: torch.fx.Node, mod: torch.fx.GraphModule) -> torch.fx.Node:
|
||||
with node.graph.inserting_before(node):
|
||||
div_kwargs = dict(node.kwargs)
|
||||
if "rounding_mode" not in div_kwargs and node.op == "call_function":
|
||||
div_kwargs["rounding_mode"] = None
|
||||
if node.target is torch.floor_divide:
|
||||
div_kwargs["rounding_mode"] = "trunc"
|
||||
elif node.target is operator.floordiv:
|
||||
div_kwargs["rounding_mode"] = "floor"
|
||||
elif node.target is operator.truediv:
|
||||
div_kwargs["rounding_mode"] = None
|
||||
div_node = node.graph.call_function(div, kwargs=div_kwargs)
|
||||
if "rounding_mode" not in div_kwargs or div_kwargs["rounding_mode"] is None:
|
||||
div_node = node.graph.call_function(div, kwargs={"input": div_kwargs["input"], "other": div_kwargs["other"]})
|
||||
elif div_kwargs["rounding_mode"] == "trunc":
|
||||
div_node = node.graph.call_function(trunc_div, kwargs={"input": div_kwargs["input"], "other": div_kwargs["other"]})
|
||||
elif div_kwargs["rounding_mode"] == "floor":
|
||||
div_node = node.graph.call_function(floor_div, kwargs={"input": div_kwargs["input"], "other": div_kwargs["other"]})
|
||||
else:
|
||||
raise RuntimeError(f"Unhandled div rounding mode {div_kwargs['rounding_mode']}")
|
||||
div_node.meta = node.meta.copy()
|
||||
return div_node
|
||||
|
||||
|
||||
@register_acc_op_properties(AccOpProperty.pointwise)
|
||||
@register_acc_op_mapping(op_and_target=("call_function", operator.truediv))
|
||||
@register_acc_op
|
||||
def div(*, input, other, rounding_mode=None):
|
||||
return torch.div(input, other, rounding_mode=rounding_mode)
|
||||
def div(*, input, other):
|
||||
return input / other
|
||||
|
||||
|
||||
@register_acc_op_properties(AccOpProperty.pointwise)
|
||||
@register_acc_op_mapping(op_and_target=("call_function", operator.floordiv))
|
||||
@register_acc_op
|
||||
def floor_div(*, input, other):
|
||||
# This is temp fix because currently operator.floor_div for tensors would
|
||||
# traslate into torch.floor_divide which would throw an error. After it's
|
||||
# fixed we can stick to `input // other`.
|
||||
if isinstance(input, torch.Tensor) or isinstance(other, torch.Tensor):
|
||||
return torch.div(input, other, rounding_mode="floor")
|
||||
return input // other
|
||||
|
||||
|
||||
# torch.floor_divide rounds result toward zero, rather than -Inf.
|
||||
# https://github.com/pytorch/pytorch/issues/43874
|
||||
@register_acc_op_mapping(op_and_target=("call_function", torch.floor_divide))
|
||||
@register_acc_op_properties(AccOpProperty.pointwise)
|
||||
@register_acc_op
|
||||
def trunc_div(*, input, other):
|
||||
return torch.div(input, other, rounding_mode="trunc")
|
||||
|
||||
|
||||
@register_acc_op_properties(AccOpProperty.pointwise)
|
||||
|
@ -171,7 +171,7 @@ class OpSupports:
|
||||
submodules: t.Mapping[str, torch.nn.Module],
|
||||
node: torch.fx.Node,
|
||||
) -> bool:
|
||||
for arg in node._input_nodes:
|
||||
for arg in node.all_input_nodes:
|
||||
# escape dtype check for get_attr node
|
||||
if arg.op == "get_attr":
|
||||
continue
|
||||
|
Reference in New Issue
Block a user